1. LibSVM简介
LibSVM是台湾著名教授陈智仁团队的杰作。具有各个语言版本的接口,包括C/C++、Java、Python、Matlab、C# 等等。这套库运算速度还是挺快的,可以很方便的对数据做分类或回归。由于libSVM程序小,运用灵活,输入参数少,并且是开源的,易于扩展,因此成为目前国内应用最多的SVM的库。
这套库可以从http://www.csie.ntu.edu.tw/~cjlin/免费获得,目前已经发展到3.22版。下载.zip格式的版本,解压后可以看到,主要有5个文件夹和一些c++源码文件。
2. 数据准备
Readme里面几乎包含了所有可以帮助你灵活使用LibSVM的信息,可是很多人都不怎么看。这里给出用Java调用LibSVM API最简单的示例,用JAVA进行SVM分类只需要几行程序就搞定了,前提是你已经准备好了符合LibSVM处理数据格式的训练样本和测试样本。下面一一道来:
2.1 准备训练样本和测试样本
直接上LibSVM官网就可以下载,我下载的UCI的UCI-breast-cancer数据集,训练样本和测试样本的基本格式是这样的:
<label> <index1>:<value1> <index2>:<value2> ...
分别代表: 类别 feature1索引:feature1值 feature2索引:feature2值…
3. Java API
建立JAVA工程,导入LibSVM 的JAR包,要注意还需要导入java文件下的svm_train.java和svm_predict.java这两个文件,这两个类其实主要在LibSVM基础上做了进一步封装,把命令行参数转化成了String []类型的函数参数,方便API调用。至于另外两个svm_tony.java和svm_scale可以不导入,它们分别是图形界面和数据压缩用的,不是必要文件。
把训练样本和测试样本放在工程文件夹下,当然,你也可以自定义data目录。
修改svm_train.java和svm_predict.java这两个文件,前者主要是把model_file_name返回,因为在svm_predict的main函数中需要使用,后者主要是把分类的Accuracy返回。
编写JAVA调用LibSVM API分类代码如下,非常简单,代码中给出了注释
import java.io.IOException;
import libsvm.*;
/**JAVA test code for LibSVM
* Created by zhanghuayan on 2017/1/3.
*/
public class LibSVMTest {
public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
//Test for svm_train and svm_predict
//svm_train:
// param: String[], parse result of command line parameter of svm-train
// return: String, the directory of modelFile
//svm_predect:
// param: String[], parse result of command line parameter of svm-predict, including the modelfile
// return: Double, the accuracy of SVM classification
String[] trainArgs = {"UCI-breast-cancer-tra"};//directory of training file
String modelFile = svm_train.main(trainArgs);
String[] testArgs = {"UCI-breast-cancer-test", modelFile, "UCI-breast-cancer-result"};//directory of test file, model file, result file
Double accuracy = svm_predict.main(testArgs);
System.out.println("SVM Classification is done! The accuracy is " + accuracy);
//Test for cross validation
//String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"};// 10 fold cross validation
//modelFile = svm_train.main(crossValidationTrainArgs);
//System.out.print("Cross validation is done! The modelFile is " + modelFile);
}
}
对机器学习,人工智能感兴趣的小伙伴,请关注我的公众号: