1. LibSVM簡介
LibSVM是台灣著名教授陳智仁團隊的傑作。具有各個語言版本的接口,包括C/C++、Java、Python、Matlab、C# 等等。這套庫運算速度還是挺快的,可以很方便的對數據做分類或回歸。由於libSVM程序小,運用靈活,輸入參數少,並且是開源的,易於擴展,因此成為目前國內應用最多的SVM的庫。
這套庫可以從http://www.csie.ntu.edu.tw/~cjlin/免費獲得,目前已經發展到3.22版。下載.zip格式的版本,解壓后可以看到,主要有5個文件夾和一些c++源碼文件。
2. 數據准備
Readme里面幾乎包含了所有可以幫助你靈活使用LibSVM的信息,可是很多人都不怎么看。這里給出用Java調用LibSVM API最簡單的示例,用JAVA進行SVM分類只需要幾行程序就搞定了,前提是你已經准備好了符合LibSVM處理數據格式的訓練樣本和測試樣本。下面一一道來:
2.1 准備訓練樣本和測試樣本
直接上LibSVM官網就可以下載,我下載的UCI的UCI-breast-cancer數據集,訓練樣本和測試樣本的基本格式是這樣的: :: ...
分別代表: 類別 feature1索引:feature1值 feature2索引:feature2值…
3. Java API
建立JAVA工程,導入LibSVM 的JAR包,要注意還需要導入java文件下的svm_train.java和svm_predict.java這兩個文件,這兩個類其實主要在LibSVM基礎上做了進一步封裝,把命令行參數轉化成了String []類型的函數參數,方便API調用。至於另外兩個svm_tony.java和svm_scale可以不導入,它們分別是圖形界面和數據壓縮用的,不是必要文件。
把訓練樣本和測試樣本放在工程文件夾下,當然,你也可以自定義data目錄。
修改svm_train.java和svm_predict.java這兩個文件,前者主要是把model_file_name返回,因為在svm_predict的main函數中需要使用,后者主要是把分類的Accuracy返回。
編寫JAVA調用LibSVM API分類代碼如下,非常簡單,代碼中給出了注釋import java.io.IOException;
import libsvm.*;
/**JAVA test code for LibSVM
* Created by zhanghuayan on 2017/1/3.
*/
public class LibSVMTest {
public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
//Test for svm_train and svm_predict
//svm_train:
// param: String[], parse result of command line parameter of svm-train
// return: String, the directory of modelFile
//svm_predect:
// param: String[], parse result of command line parameter of svm-predict, including the modelfile
// return: Double, the accuracy of SVM classification
String[] trainArgs = {"UCI-breast-cancer-tra"};//directory of training file
String modelFile = svm_train.main(trainArgs);
String[] testArgs = {"UCI-breast-cancer-test", modelFile, "UCI-breast-cancer-result"};//directory of test file, model file, result file
Double accuracy = svm_predict.main(testArgs);
System.out.println("SVM Classification is done! The accuracy is " + accuracy);
//Test for cross validation
//String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"};// 10 fold cross validation
//modelFile = svm_train.main(crossValidationTrainArgs);
//System.out.print("Cross validation is done! The modelFile is " + modelFile);
}
}
對機器學習,人工智能感興趣的小伙伴,請關注我的公眾號: