参考文献:
LibSVM使用指南 http://www.cnblogs.com/zhangchaoyang/articles/2189606.html
3行程序搞定SVM分类-用JAVA程序调用LibSVM API 最简单的示例 http://blog.csdn.net/datoubo/article/details/8583558
写在前面:最好的学习资料是自带的readme文件。
1、数据准备
以下是readme中的内容:
The format of training and testing data file is:
<label> <index1>:<value1> <index2>:<value2> ...
.
.
.
Each line contains an instance and is ended by a '\n' character. For
classification, <label> is an integer indicating the class label
(multi-class is supported). For regression, <label> is the target
value which can be any real number. For one-class SVM, it's not used
so can be any number. The pair <index>:<value> gives a feature
(attribute) value: <index> is an integer starting from 1 and <value>
is a real number. The only exception is the precomputed kernel, where
<index> starts from 0; see the section of precomputed kernels. Indices
must be in ASCENDING order. Labels in the testing file are only used
to calculate accuracy or errors. If they are unknown, just fill the
first column with any numbers.
即: <label> <index1>:<value1> <index2>:<value2> ...
2、简单的Java调用
建立JAVA工程,导入LibSVM 的JAR包,要注意还需要导入java文件下的svm_train.java和svm_predict.java这两个文件,这两个类其实主要在LibSVM基础上做了进一步封装,把命令行参数转化成了String []类型的函数参数,方便API调用。至于另外两个svm_tony.java和svm_scale可以不导入,它们分别是图形界面和数据压缩用的,不是必要文件。
public static void main(String[] args) throws IOException {
String []arg ={"-h","0",
"-v","5", //设置交叉验证数目
"./data/weibo", //存放SVM训练模型用的数据的路径
"./data/model"}; //存放SVM通过训练数据训练出来的模型的路径
/* String []parg={"./data/weibo", //这个是存放测试数据
"./data/model", //调用的是训练以后的模型
"./data/out.txt"}; //生成的结果的文件的路径
*/
System.out.println("........SVM is runnig..........");
//创建一个训练对象
svm_train t = new svm_train();
//创建一个预测或者分类的对象
svm_predict p= new svm_predict();
t.main(arg); //使用模型训练
// p.main(parg); //调用模型测试
}
对于文本分类svm_train中有几个选项会用到:
-s svm_type : set type of SVM (default 0)
0 -- C-SVC
-t kernel_type : set type of kernel function (default 2)
0 -- linear: u'*v
2 -- radial basis function: exp(-gamma*|u-v|^2)
-g gamma : set gamma in kernel function (default 1/num_features) num_features是输入向量的个数
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-m cachesize : set cache memory size in MB (default 100) 使用多少内存
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)
-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1) 当各类数量不均衡时为每个类分别指定C
-v n: n-fold cross validation mode交叉验证时分为多少组
-q : quiet mode (no outputs)
libSVM的使用主要有两种方法:第一种是基于训练样本训练分类model(分类超平面),然后对测试样本进行分类;第二种直接对训练样本采用n折交叉验证法测试SVM分类性能。
附:svm 主调用程序经常使用的几个方法,注意这是LibSVM中SVM类下的函数,与svm_train.java等文件中的相关函数要区分:
svm.svm_train(svm_problem,svm_parameter) 该方法返回一个训练好的svm_model
svm.svm_load_model(文件名); 该方法返回一个训练好的svm_model
svm.svm_save_model(文件名,svm_model); 该方法将svm_model保存到文件中
svm.svm_predict_values(svm_model,svm_node,double); 该方法返回doule类值,svm_node对svm_model测试,返回值确定了svm_node在模型中的定位
3、结果分析
模型训练后会生成model文件,内容如下:
svm_type c_svc
kernel_type rbf
gamma 0.5
nr_class 2
total_sv 9
rho -0.5061570424019811
label 1 0
nr_sv 4 5
SV
2.7686973549711875 1:0.21428571428571427 2:0.3333333333333333
5.0 1:0.35714285714285715 2:0.75
5.0 1:0.8571428571428571 2:0.08333333333333333
5.0 1:0.5714285714285714 2:0.5833333333333334
-5.0 1:0.6428571428571429 2:0.6666666666666666
-2.4351637665059895 1:0.42857142857142855 2:1.0
-5.0 1:0.7142857142857143 2:0.6666666666666666
-5.0 1:0.5714285714285714 2:0.4166666666666667
-0.3335335884651968 1:1.0 2:0.6666666666666666
nr_class代表训练样本集有几类,
rho是判决函数的常数项b,
nr_sv是各个类中落在边界上的向量个数,
SV下面枚举了所有的支持向量,每个支持向量前面都有一个数字,代表什么我现在也不清楚。
当训练模型时控制台会有类似下面的输出:
optimization finished, #iter = 219
nu = 0.431030
obj = -100.877286, rho = 0.424632
nSV = 132, nBSV = 107
Total nSV = 132
obj是对SSVM问题的优化目标函数的值。rho是决策函数中的常数项b。nSV是支持向量的个数,nBSV是边界支持向量的个数(i.e., alpha_i = C)。
如果“自由支持向量”个数很多,很可能是因为过拟合了。如果输入数据的attribute在一个很大的范围内分布,最好scale一下。
采用时默认的核函数RBF是比较好的,if RBF is used with model selection, then there is no need to consider the linear kernel.
如果预测的准确率太低,如何提高一下?使用python目录下的grid.py进行模型选择以找到比较好的参数。
grid.py是一种用于RBF核函数的C-SVM分类的参数选择程序。用户只需给定参数的一个范围,grid.py采用交叉验证的方法计算每种参数组合的准确度来找到最好的参数。