java学习--Libsvm java版代码注释及详解(一)

由于工作中要用到SVR算法,项目组的系统是用java开发的,因此,为了能与项目组同步,算法需要用java来实现,还好台湾大学的林智仁教授推出了Libsvm的源代码,包括java、c++等语言的源代码,在此表示感谢!因此,算法的主体部分不用自己开发了,在源代码的基础上做一些修改就能够应用到自己的项目中了,开源真好!受益了无数人。。。为了弘扬开源的精神,开博记录学习Libsvmjava版源代码的过程。下面正式开始!先从SVR回归算法的代码开始,然后逐步扩展到分类算法。希望自己能够坚持下去。加油。

 

---------------------我是华丽的分割线-----------高手的小jj,我割割割--嘻嘻-------------------

一、初识Libsvm

    Libsvmjava版本的源代码很容易下载到,为了使算法能够运行,要先把源代码复制到myeclipse中(貌似是废话),在这里,提供一个链接,里面有很好的说明,按照文档中的说明就能够将Libsvm运行起来。

链接地址:http://wenku.baidu.com/view/54cfa92b453610661ed9f4f6.html

   按照上述链接中的方式,自己新建一个main函数,来调用Libsvm算法的源代码,代码如下:

public static void main(String[] args) throws IOException {

   

    String[]arg ={ "trainfile\\train1.txt", //存放SVM训练模型用的数据的路径

                  "trainfile\\model_r.txt"}; //存放SVM通过训练数据训练出来的模型的路径

 

    String[]parg={"trainfile\\test2.txt",  //这个是存放测试数据

                   "trainfile\\model_r.txt", //调用的是训练以后的模型

                   "trainfile\\out_r.txt"}; //生成的结果的文件的路径

    System.out.println("........SVM运行开始.........."); 

    //创建一个训练对象

    svm_traint = new svm_train(); 

    //创建一个预测或者分类的对象

    svm_predictp= new svm_predict(); 

    t.main(arg);  //调用

    p.main(parg); //调用

}

注意:

1. 该主函数是为了调用Libsvm源代码的,创建了两个字符数组:arg[]parg[]

   其中arg[]数组存放了两个字符串,"trainfile\\train1.txt"和"trainfile\\model_r.txt",这两个字符串传入到Libsvm中的svm_train();

   trainfile\\train1.txt这个文件里面存储的是训练数据,训练数据的格式如下:

     index1:特征1  index2:特征2   ....

<lable1> 1:特征1      2:特征2      3:特征3 ...

<lable2> 1:特征1      2:特征2      3:特征3 ...

其中如果是分类问题,lable为类标签。如果是回归问题,lable为具体的实数。为了生成这种格式的数据,可以采用FormatDataLibsvm.exl这个excel文件生成,网上可以下的到。当然,也可以自己写java代码,来生成这种格式的数据。代码随后奉上。而trainfile\\model_r.txt这个文件的作用是存储利用SVM算法训练好的模型。

   parg[]这个数组存放了3个字符串:trainfile\\test2.txt---用于存储测试样本的文件,数据格式与训练样本的格式一样、trainfile\\model_r.txt---前面生成的训练好的模型,在用新样本进行预测时直接使用前文训练好的模型即可、trainfile\\model_r.txt---用于存储模型的预测值,分类问题的话存储的是预测样本每条样本所属的类别,而回归问题的话,存储的是每条样本所对应的预测值。

2. svm_train()与svm_predict()。这两个函数是Libsvm程序包中的源代码,由此即进入到了Libsvm源代码的世界了。

 

二、Libsvm的真面目

   由于本篇是介绍支持向量回归机--SVR的,所以仅从用到SVR算法的代码入手,来分析Libsvm的源代码。分析时采用逐层进入的方式。下面直接上程序代码目录:

 

+++++++目录+++++

 

   前文是首先调用svm_train()然后调用svm_predict()。那就先从svm_train()说起。

 classsvm_train{}包含以下几个变量以及函数

 private svm_parameterparam // 用于设置svm模型的参数
 private svm_problemprob// 用来存储样本序号、样本的目标变量Y、样本自变量X.详看classsvm_problem
 private svm_modelmodel;// ??
 private Stringinput_file_name // 输入文件名
 private Stringmodel_file_name // 模型文件名
 private Stringerror_msg;//错误信息
 private intcross_validation;//交叉验证
 private intnr_fold;// ??

 private staticsvm_print_interface svm_print_null = newsvm_print_interface() //

 private static voidexit_with_help() //打印帮助信息

 private voiddo_cross_validation() //交叉验证

 private void run(Stringargv[]) throws IOException //运行svm训练程序

 public static voidmain(String argv[]) throws IOException// 主函数

 private static doubleatof(String s) //将字符串转化为浮点型

 private static intatoi(String s) //将字符串转化为int型

 private voidparse_command_line(String argv[]) //设置参数

 private voidread_problem() throws IOException //读取错误信息

 

前文自己写的那个函数,调用的svm训练模型,即

svm_train t = new svm_train();

t.main(arg); //调用svm训练模型

程序进入到class svm_train{ }中的main函数,即 publicstatic void main(String argv[]) throws IOException该主函数的代码如下:

public static void main(String argv[])throws IOException
 {
  svm_train t = newsvm_train();
  t.run(argv);//传进来一个数组,数组里面有两个字符串,一个是训练样本.txt,一个是训练好的模型.txt
 }

继续执行run()函数,输入为数组argv[].

private void run(String argv[]) throwsIOException{
    parse_command_line(argv); // 1.进入到该函数中,获取SVM参数
    read_problem(); // 2.进入到该函数中,读取错误信息
    error_msg =svm.svm_check_parameter(prob,param);// 3.检查参数

   //检查参数,有错误则返回各种参数错误信息,无错误则返回null;

   if(error_msg != null)
    {
      System.err.print("ERROR: "+error_msg+"\n");
      System.exit(1);
    }

   if(cross_validation != 0)
    {
       do_cross_validation(); // 4.交叉验证
    }
    else
    {
      model = svm.svm_train(prob,param); //5.prob--训练样本,param--SVM模型参数
      svm.svm_save_model(model_file_name,model); //6.保存训练好的模型
    }

}

注:该函数一共调用了6个函数,下文一一说明。

首先进入函数1:parse_command_line(argv); // 1.进入到该函数中,获取SVM参数。该函数的输入为argv[],即两个字符串:一个是训练样本.txt,一个是训练好的模型.txt。该函数虽然无返回值,但在函数里面,已经将svm的一些参数存储在param中了,详细参数名称见classsvm_parameter,因此模型训练时已经有了所需要的各种参数。函数的详细代码如下:

private void _lineparse_command(Stringargv[])

{
   int i;
   svm_print_interface print_func= null; // default printing to stdout

   param = newsvm_parameter();//开始设置SVM模型的各种参数
   // default values
   //param.svm_type =svm_parameter.C_SVC;
  
   param.svm_type = svm_parameter.EPSILON_SVR;//此时运行的是SVR算法
   param.kernel_type = svm_parameter.RBF;//核函数取径向基核函数
   param.degree = 3; //??
   param.gamma = 0.08; 

  //gamma为RBF核函数的参数,默认时=1/num_features 此时设置为0.08 gamma=1/2*sig^2sig=2.5

  //RBF核函数:exp(-gamma*|Xi-Xj|^2)
   param.coef0 = 0; //??
   param.nu= 0.5; //??
  
   param.cache_size = 100; //设置缓存的大小
   param.C= 100; //惩罚参数
   param.eps = 0.005; //??
   param.p= 0.001; //此值为EPSILON_SVR中EPSILON
   param.shrinking = 1; //??
   param.probability = 0; //概率估计??
   param.nr_weight = 0; //权重??
   param.weight_label = new int[0]; //??
   param.weight = new double[0];
   cross_validation = 0;//交叉验证。0--不进行交叉验证。1--交叉验证
       
   //获取输入参数
   // parse options
  for(i=0;i<argv.length;i++) //argv.length=2,即有两个字符串
    {
     if(argv[i].charAt(0) != '-') break;

    //由于第一个字符(trainfile\\train1.txt)中的第一个字符不是‘-’,果断break!退出for循环。i=0
     if(++i>=argv.length)
     exit_with_help();
     switch(argv[i-1].charAt(1))
       {
            case 's':
                param.svm_type = atoi(argv[i]);
                break;
            case 't':
                param.kernel_type = atoi(argv[i]);
                break;
            case 'd':
                param.degree = atoi(argv[i]);
                break;
            case 'g':
                param.gamma = atof(argv[i]);
                break;
            case 'r':
                param.coef0 = atof(argv[i]);
                break;
            case 'n':
                 param.nu= atof(argv[i]);
                break;
            case 'm':
                param.cache_size = atof(argv[i]);
                break;
             case'c':
                param.C = atof(argv[i]);
                break;
             case'e':
                param.eps = atof(argv[i]);
                break;
            case 'p':
                param.p = atof(argv[i]);
                break;
            case 'h':
                param.shrinking = atoi(argv[i]);
                break;
            case 'b':
                 param.probability= atoi(argv[i]);
                break;
            case 'q':
                print_func = svm_print_null;
                i--;
                break;
            case 'v':
                cross_validation = 1;
                nr_fold = atoi(argv[i]);
                if(nr_fold < 2)
                  {
                    System.err.print("n-fold cross validation: n must >=2\n");
                    exit_with_help();
                  }
                break;
          case'w':
                ++param.nr_weight;
                 {
                   int[] old = param.weight_label;
                   param.weight_label = new int[param.nr_weight];
                   System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
                 }

                {
                   double[] old = param.weight;
                   param.weight = new double[param.nr_weight];
                    System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
                }

               param.weight_label[param.nr_weight-1] =atoi(argv[i-1].substring(2));
               param.weight[param.nr_weight-1] = atof(argv[i]);
               break;
          default:
               System.err.print("Unknown option: " + argv[i-1] + "\n");
               exit_with_help();
        } //switch循环结束
     }//for循环结束
        
     svm.svm_set_print_string_function(print_func);//1.1打印,详见下文说明2;

     // determine filenames

     if(i>=argv.length) //argv.length=2,而i=0,不执行此语句
          exit_with_help();

     input_file_name =argv[i]; //训练样本的文件名,即trainfile\\data_train_svr.txt

     if(i<argv.length-1)//i=0,argv.length-1=1,符合条件
          model_file_name = argv[i+1];//模型文件名,即trainfile\\model_r.txt
      else//此时不执行下面语句
     {
          int p = argv[i].lastIndexOf('/');
           ++p; //whew...
          model_file_name = argv[i].substring(p)+".model";
      }
 }//函数parse_command_line结束

 

 说明:

1.该函数的功能是初始化svm模型的各种参数,本篇用的是SVR算法,初始化了一些参数。

2.该函数调用了一个函数,即函数1.1,由于该函数的输入是print_func =null,经过调用,其输出也为空,即不打印任何信息,因此本文不予深入说明。

 

此时程序进入函数2:read_problem() ; // 2.进入到该函数中,读取错误信息

该函数无返回值,但在函数体内针对两个错误,用打印输出语句打印出相应的错误信息,其中一个错误信息为:核函数的第一列的标签必须从0开始编号。如果不是从0编号,则打印输出此错误信息。第二个错误为:样本格式有错误,如果样本的编号标签小于0或者样本的编号标签值大于样本的实际个数,则打印输出该错误信息。

 private voidread_problem() throwsIOException
 {
     BufferedReaderfp = new BufferedReader(new FileReader(input_file_name));
     Vector<Double>vy = new Vector<Double>();
    Vector<svm_node[]> vx = newVector<svm_node[]>();
    int max_index = 0;

    while(true)
     {
         String line = fp.readLine();
         if(line == null) break;

         StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");

         vy.addElement(atof(st.nextToken()));//atof--将字符串转化为数字
         int m = st.countTokens()/2;//训练样本的特征个数,Y,X1,X2...Xm-1.共m个
         svm_node[] x = new svm_node[m];
         for(int j=0;j<m;j++)
           {
               x[j] = new svm_node();
               x[j].index = atoi(st.nextToken());
               x[j].value = atof(st.nextToken());
            }
        if(m>0) max_index = Math.max(max_index,x[m-1].index);
         vx.addElement(x);
     }

   prob = newsvm_problem();
   prob.l = vy.size();
   prob.x = newsvm_node[prob.l][];
   for(inti=0;i<prob.l;i++)
      prob.x[i] = vx.elementAt(i);
   prob.y = newdouble[prob.l];
   for(inti=0;i<prob.l;i++)
      prob.y[i] = vy.elementAt(i);

   if(param.gamma == 0&& max_index >0)
      param.gamma = 1.0/max_index;

   if(param.kernel_type ==svm_parameter.PRECOMPUTED)
      for(int i=0;i<prob.l;i++)
         {
           if(prob.x[i][0].index != 0)
            {
         System.err.print("Wrong kernel matrix: first column must be0:sample_serial_number\n");
         System.exit(1);
             }
         if ((int)prob.x[i][0].value <= 0 ||(int)prob.x[i][0].value > max_index)
            {
          System.err.print("Wrong input format:sample_serial_number out of range\n");
           System.exit(1);
             }
        }

   fp.close();
  }

}

 

 

 

 此时程序进入函数3:error_msg =svm.svm_check_parameter(prob,param);// 3.检查参数

 该函数是在class svm中,功能是检查svm模型的参数是否正确。

 public static Stringsvm_check_parameter(svm_problem prob, svm_parameterparam) 

 该函数的输入是svm_problemprobsvm_parameterparam两个类,其中类svm_problem如下:

 public class svm_problemimplements java.io.Serializable
{
    public intl;//训练样本中,样本的标签,即第l个训练样本
    publicdouble[] y;//训练样本的目标变量Y
    publicsvm_node[][] x;//训练样本的自变量X
}
svm_parameter则是svm所需要的各种参数。

输出则是一个字符串,如果匹配到相应的错误,则输出其错误信息,如果没有错误,则返回NULL.

 

 

 此时程序进入函数4:do_cross_validation(); // 4.交叉验证

 由于SVR算法不需要交叉验证,故不执行此函数。而对于分类而言,执行交叉验证操作可增强算法的推广能力。

 这里留作以后详细研究。

 

此时程序进入函数5:model =svm.svm_train(prob,param); //prob--训练样本,param--SVM模型参数

此时进入到了SVM/SVR算法的关键环节--训练模型。欲知详情,请听下回分解。

 

-----格格-----------------

 

参考文献:

1.http://wenku.baidu.com/view/54cfa92b453610661ed9f4f6.html-----很好的libsvm安装使用说明

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值