java libsvm回归_java学习--Libsvm java版代码注释及详解(一)

由于工作中要用到SVR算法,项目组的系统是用java开发的,因此,为了能与项目组同步,算法需要用java来实现,还好台湾大学的林智仁教授推出了Libsvm的源代码,包括java、c++等语言的源代码,在此表示感谢!因此,算法的主体部分不用自己开发了,在源代码的基础上做一些修改就能够应用到自己的项目中了,开源真好!受益了无数人。。。为了弘扬开源的精神,开博记录学习Libsvm java版源代码的过程。下面正式开始!先从SVR回归算法的代码开始,然后逐步扩展到分类算法。希望自己能够坚持下去。加油。

---------------------我是华丽的分割线-----------高手的小jj,我割割割--嘻嘻-------------------

一、初识Libsvm

Libsvm java版本的源代码很容易下载到,为了使算法能够运行,要先把源代码复制到myeclipse中(貌似是废话),在这里,提供一个链接,里面有很好的说明,按照文档中的说明就能够将Libsvm运行起来。

按照上述链接中的方式,自己新建一个main函数,来调用Libsvm算法的源代码,代码如下:

public static void main(String[] args) throws IOException {

String []arg ={ "trainfile\\train1.txt", //存放SVM训练模型用的数据的路径

"trainfile\\model_r.txt"};  //存放SVM通过训练数据训练出来的模型的路径

String []parg={"trainfile\\test2.txt",   //这个是存放测试数据

"trainfile\\model_r.txt",  //调用的是训练以后的模型

"trainfile\\out_r.txt"};  //生成的结果的文件的路径

System.out.println("........SVM运行开始..........");

//创建一个训练对象

svm_train t = new svm_train();

//创建一个预测或者分类的对象

svm_predict p= new svm_predict();

t.main(arg);   //调用

p.main(parg);  //调用

}

注意:

1. 该主函数是为了调用Libsvm源代码的,创建了两个字符数组:arg[]和parg[]。

其中arg[]数组存放了两个字符串,"trainfile\\train1.txt"和"trainfile\\model_r.txt",这两个字符串传入到Libsvm中的svm_train();

在trainfile\\train1.txt这个文件里面存储的是训练数据,训练数据的格式如下:

Y     index1:特征1   index2:特征2    ....

 1:特征1       2:特征2       3:特征3 ...

  1:特征1       2:特征2       3:特征3 ...

其中如果是分类问题,lable为类标签。如果是回归问题,lable为具体的实数。为了生成这种格式的数据,可以采用FormatDataLibsvm.exl这个excel文件生成,网上可以下的到。当然,也可以自己写java代码,来生成这种格式的数据。代码随后奉上。而trainfile\\model_r.txt这个文件的作用是存储利用SVM算法训练好的模型。

而parg[]这个数组存放了3个字符串:trainfile\\test2.txt---用于存储测试样本的文件,数据格式与训练样本的格式一样、trainfile\\model_r.txt---前面生成的训练好的模型,在用新样本进行预测时直接使用前文训练好的模型即可、trainfile\\model_r.txt---用于存储模型的预测值,分类问题的话存储的是预测样本每条样本所属的类别,而回归问题的话,存储的是每条样本所对应的预测值。

2. svm_train()与svm_predict()。这两个函数是Libsvm程序包中的源代码,由此即进入到了Libsvm源代码的世界了。

二、Libsvm的真面目

由于本篇是介绍支持向量回归机--SVR的,所以仅从用到SVR算法的代码入手,来分析Libsvm的源代码。分析时采用逐层进入的方式。下面直接上程序代码目录:

+++++++目录+++++

前文是首先调用svm_train()然后调用svm_predict()。那就先从svm_train()说起。

class svm_train{}包含以下几个变量以及函数:

private svm_parameter param;  // 用于设置svm模型的参数

private svm_problem prob; // 用来存储样本序号、样本的目标变量Y、样本自变量X.详看class svm_problem

private svm_model model;// ??

private String input_file_name;  // 输入文件名

private String model_file_name;  // 模型文件名

private String error_msg;//错误信息

private int cross_validation;//交叉验证

private int nr_fold;// ??

private static svm_print_interface svm_print_null = new svm_print_interface() //

private static void exit_with_help() //打印帮助信息

private void do_cross_validation() //交叉验证

private void run(String argv[]) throws IOException //运行svm训练程序

public static void main(String argv[]) throws IOException // 主函数

private static double atof(String s) //将字符串转化为浮点型

private static int atoi(String s) //将字符串转化为int型

private void parse_command_line(String argv[]) //设置参数

private void read_problem() throws IOException //读取错误信息

前文自己写的那个函数,调用的svm训练模型,即

svm_train t = new svm_train();

t.main(arg); //调用svm训练模型

程序进入到class svm_train{ }中的main函数,即 public static void main(String argv[]) throws IOException 该主函数的代码如下:

public static void main(String argv[]) throws IOException {

svm_train t = new svm_train();

t.run(argv);//传进来一个数组,数组里面有两个字符串,一个是训练样本.txt,一个是训练好的模型.txt

}

继续执行run()函数,输入为数组argv[].

private void run(String argv[]) throws IOException{

parse_command_line(argv); // 1.进入到该函数中,获取SVM参数

read_problem(); // 2.进入到该函数中,读取错误信息

error_msg = svm.svm_check_parameter(prob,param); // 3.检查参数

//检查参数,有错误则返回各种参数错误信息,无错误则返回null;

if(error_msg != null)

{

System.err.print("ERROR: "+error_msg+"\n");

System.exit(1);

}

if(cross_validation != 0)

{

do_cross_validation(); // 4.交叉验证

}

else

{

model = svm.svm_train(prob,param); // 5.prob--训练样本,param--SVM模型参数

svm.svm_save_model(model_file_name,model); // 6.保存训练好的模型

}

}

注:该函数一共调用了6个函数,下文一一说明。

首先进入函数1:parse_command_line(argv); // 1.进入到该函数中,获取SVM参数。该函数的输入为argv[],即两个字符串:一个是训练样本.txt,一个是训练好的模型.txt。该函数虽然无返回值,但在函数里面,已经将svm的一些参数存储在param中了,详细参数名称见class svm_parameter,因此模型训练时已经有了所需要的各种参数。函数的详细代码如下:

private void _lineparse_command(String argv[])

{

int i;

svm_print_interface print_func = null; // default printing to stdout

param = new svm_parameter();//开始设置SVM模型的各种参数

// default values

//param.svm_type = svm_parameter.C_SVC;

param.svm_type = svm_parameter.EPSILON_SVR; //此时运行的是SVR算法

param.kernel_type = svm_parameter.RBF; //核函数取径向基核函数

param.degree = 3; //??

param.gamma = 0.08;

//gamma为RBF核函数的参数,默认时=1/num_features 此时设置为0.08 gamma=1/2*sig^2 sig=2.5

//RBF核函数:exp(-gamma*|Xi-Xj|^2)

param.coef0 = 0; //??

param.nu = 0.5; //??

param.cache_size = 100; //设置缓存的大小

param.C = 100; //惩罚参数

param.eps = 0.005; //??

param.p = 0.001; //此值为EPSILON_SVR中EPSILON

param.shrinking = 1; //??

param.probability = 0; //概率估计??

param.nr_weight = 0; //权重??

param.weight_label = new int[0]; //??

param.weight = new double[0];

cross_validation = 0; //交叉验证。0--不进行交叉验证。1--交叉验证

//获取输入参数

// parse options

for(i=0;i

{

if(argv[i].charAt(0) != '-') break;

//由于第一个字符(trainfile\\train1.txt)中的第一个字符不是‘-’,果断break!退出for循环。i=0

if(++i>=argv.length)

exit_with_help();

switch(argv[i-1].charAt(1))

{

case 's':

param.svm_type = atoi(argv[i]);

break;

case 't':

param.kernel_type = atoi(argv[i]);

break;

case 'd':

param.degree = atoi(argv[i]);

break;

case 'g':

param.gamma = atof(argv[i]);

break;

case 'r':

param.coef0 = atof(argv[i]);

break;

case 'n':

param.nu = atof(argv[i]);

break;

case 'm':

param.cache_size = atof(argv[i]);

break;

case 'c':

param.C = atof(argv[i]);

break;

case 'e':

param.eps = atof(argv[i]);

break;

case 'p':

param.p = atof(argv[i]);

break;

case 'h':

param.shrinking = atoi(argv[i]);

break;

case 'b':

param.probability = atoi(argv[i]);

break;

case 'q':

print_func = svm_print_null;

i--;

break;

case 'v':

cross_validation = 1;

nr_fold = atoi(argv[i]);

if(nr_fold < 2)

{

System.err.print("n-fold cross validation: n must >= 2\n");

exit_with_help();

}

break;

case 'w':

++param.nr_weight;

{

int[] old = param.weight_label;

param.weight_label = new int[param.nr_weight];

System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);

}

{

double[] old = param.weight;

param.weight = new double[param.nr_weight];

System.arraycopy(old,0,param.weight,0,param.nr_weight-1);

}

param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));

param.weight[param.nr_weight-1] = atof(argv[i]);

break;

default:

System.err.print("Unknown option: " + argv[i-1] + "\n");

exit_with_help();

} //switch循环结束

}//for循环结束

svm.svm_set_print_string_function(print_func); //1.1打印,详见下文说明2;

// determine filenames

if(i>=argv.length) //argv.length=2,而i=0,不执行此语句

exit_with_help();

input_file_name = argv[i]; //训练样本的文件名,即trainfile\\data_train_svr.txt

if(i

model_file_name = argv[i+1]; //模型文件名,即trainfile\\model_r.txt

else//此时不执行下面语句

{

int p = argv[i].lastIndexOf('/');

++p; // whew...

model_file_name = argv[i].substring(p)+".model";

}

}//函数parse_command_line结束

说明:

1.该函数的功能是初始化svm模型的各种参数,本篇用的是SVR算法,初始化了一些参数。

2.该函数调用了一个函数,即函数1.1,由于该函数的输入是print_func = null,经过调用,其输出也为空,即不打印任何信息,因此本文不予深入说明。

此时程序进入函数2:read_problem() ; // 2.进入到该函数中,读取错误信息

该函数无返回值,但在函数体内针对两个错误,用打印输出语句打印出相应的错误信息,其中一个错误信息为:核函数的第一列的标签必须从0开始编号。如果不是从0编号,则打印输出此错误信息。第二个错误为:样本格式有错误,如果样本的编号标签小于0或者样本的编号标签值大于样本的实际个数,则打印输出该错误信息。

private void read_problem() throws IOException {

BufferedReader fp = new BufferedReader(new FileReader(input_file_name));

Vector vy = new Vector();

Vector vx = new Vector();

int max_index = 0;

while(true)

{

String line = fp.readLine();

if(line == null) break;

StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");

vy.addElement(atof(st.nextToken()));//atof--将字符串转化为数字

int m = st.countTokens()/2;//训练样本的特征个数,Y,X1,X2...Xm-1.共m个

svm_node[] x = new svm_node[m];

for(int j=0;j

{

x[j] = new svm_node();

x[j].index = atoi(st.nextToken());

x[j].value = atof(st.nextToken());

}

if(m>0) max_index = Math.max(max_index, x[m-1].index);

vx.addElement(x);

}

prob = new svm_problem();

prob.l = vy.size();

prob.x = new svm_node[prob.l][];

for(int i=0;i

prob.x[i] = vx.elementAt(i);

prob.y = new double[prob.l];

for(int i=0;i

prob.y[i] = vy.elementAt(i);

if(param.gamma == 0 && max_index > 0)

param.gamma = 1.0/max_index;

if(param.kernel_type == svm_parameter.PRECOMPUTED)

for(int i=0;i

{

if (prob.x[i][0].index != 0)

{

System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");

System.exit(1);

}

if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)

{

System.err.print("Wrong input format: sample_serial_number out of range\n");

System.exit(1);

}

}

fp.close();

}

}

此时程序进入函数3:error_msg = svm.svm_check_parameter(prob,param); // 3.检查参数

该函数是在class svm中,功能是检查svm模型的参数是否正确。

public static String svm_check_parameter(svm_problem prob, svm_parameter param)

该函数的输入是svm_problem prob和svm_parameter param两个类,其中类svm_problem如下:

public class svm_problem implements java.io.Serializable{

public int l;//训练样本中,样本的标签,即第l个训练样本

public double[] y;//训练样本的目标变量Y

public svm_node[][] x;//训练样本的自变量X

}svm_parameter则是svm所需要的各种参数。

输出则是一个字符串,如果匹配到相应的错误,则输出其错误信息,如果没有错误,则返回NULL.

此时程序进入函数4:do_cross_validation(); // 4.交叉验证

由于SVR算法不需要交叉验证,故不执行此函数。而对于分类而言,执行交叉验证操作可增强算法的推广能力。

这里留作以后详细研究。

此时程序进入函数5:model = svm.svm_train(prob,param); //prob--训练样本,param--SVM模型参数

此时进入到了SVM/SVR算法的关键环节--训练模型。欲知详情,请听下回分解。

-----格格-----------------

参考文献:

1.http://wenku.baidu.com/view/54cfa92b453610661ed9f4f6.html-----很好的libsvm安装使用说明

  • 0
    点赞
  • 1
    收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:1024 设计师:我叫白小胖 返回首页
评论

打赏作者

异教徒

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值