1. Environment
MyEclipse 2013 + Weka 3.6 + LibSVM 3.18 + JDK 1.7 + Windows 8.1
2. Tips
1). Using Weka from Java
How:
Add weka.jar from the Weka installation folder to the project's build path.
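2). Converting a CSV file to an ARFF file
How:
Open the CSV file in the Weka Explorer and save it as an ARFF file.
Note:
If you have both a training set and a test set, copy the header of the training ARFF file (the @relation and @attribute declarations) into the test ARFF file, so that both files declare exactly the same attributes and nominal values.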
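The header copy described in the note can also be scripted. The following is a minimal sketch in plain Java, not a Weka API: the names ArffHeaderCopy and withHeaderFrom are made up for illustration, and it assumes each file has exactly one @data line (in any letter case) separating the header from the data rows. It only swaps the header; it does not check that the test rows actually fit the training declarations.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ArffHeaderCopy {

    /** Returns the index of the line that opens the data section, or -1 if there is none. */
    static int dataLineIndex(List<String> lines) {
        for (int i = 0; i < lines.size(); i++) {
            // Case-insensitive match on the "@data" keyword.
            if (lines.get(i).trim().regionMatches(true, 0, "@data", 0, 5)) {
                return i;
            }
        }
        return -1;
    }

    /** Builds a test file that uses the training header followed by the test data rows. */
    static List<String> withHeaderFrom(List<String> trainLines, List<String> testLines) {
        int trainData = dataLineIndex(trainLines);
        int testData = dataLineIndex(testLines);
        if (trainData < 0 || testData < 0) {
            throw new IllegalArgumentException("both inputs need an @data line");
        }
        // Training header, up to and including its @data line.
        List<String> result = new ArrayList<String>(trainLines.subList(0, trainData + 1));
        // Test rows, i.e. everything after the test file's own @data line.
        result.addAll(testLines.subList(testData + 1, testLines.size()));
        return result;
    }

    public static void main(String[] args) {
        List<String> train = Arrays.asList(
                "@relation demo", "@attribute class {yes,no}", "@data", "yes", "no");
        List<String> test = Arrays.asList(
                "@relation demo", "@attribute class {yes}", "@data", "yes");
        for (String line : withHeaderFrom(train, test)) {
            System.out.println(line);
        }
    }
}
```

In practice the line lists would come from reading the two .arff files; they are inlined here so the sketch runs on its own.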
3). Using LibSVM through Weka from Java
How:
Add libsvm.jar from the LibSVM folder to the project's build path.
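A quick way to confirm that both jars really ended up on the build path is to try loading one class from each. This is only a sketch: the helper name ClasspathCheck is hypothetical, and it assumes weka.jar provides weka.core.Instances and libsvm.jar provides libsvm.svm.

```java
final class ClasspathCheck {

    /** Returns true if the named class can be found by this class's class loader. */
    static boolean isOnClasspath(String className) {
        try {
            // initialize=false: only locate the class, do not run its static initializers.
            Class.forName(className, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println("weka.jar:   " + isOnClasspath("weka.core.Instances"));
        System.out.println("libsvm.jar: " + isOnClasspath("libsvm.svm"));
    }
}
```

If either line prints false, the corresponding jar is missing from the project's build path.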
3. Example
    import java.io.File;

    import weka.classifiers.Classifier;
    import weka.classifiers.Evaluation;
    import weka.core.Instances;
    import weka.core.converters.ArffLoader;

    public class WekaExample {
        public static void main(String[] args) {
            try {
                // Step 1: read the training file
                File inputFile = new File(
                        "C:\\Users\\zhangzhizhi\\Documents\\Everyone\\张志智\\总结积累\\Weka\\change_train.arff");
                ArffLoader atf = new ArffLoader();
                atf.setFile(inputFile);
                Instances instancesTrain = atf.getDataSet();
                // Step 2: read the test file
                inputFile = new File(
                        "C:\\Users\\zhangzhizhi\\Documents\\Everyone\\张志智\\总结积累\\Weka\\change_test.arff");
                atf.setFile(inputFile);
                Instances instancesTest = atf.getDataSet();
                // Step 3: use the first attribute as the class label in both sets
                instancesTest.setClassIndex(0);
                instancesTrain.setClassIndex(0);
                // Step 4: create the classifiers
                // Naive Bayes
                Classifier classifier1 = (Classifier) Class.forName(
                        "weka.classifiers.bayes.NaiveBayes").newInstance();
                // Decision tree (J48)
                Classifier classifier2 = (Classifier) Class.forName(
                        "weka.classifiers.trees.J48").newInstance();
                // ZeroR (majority-class baseline)
                Classifier classifier3 = (Classifier) Class.forName(
                        "weka.classifiers.rules.ZeroR").newInstance();
                // LibSVM (needs libsvm.jar on the build path)
                Classifier classifier4 = (Classifier) Class.forName(
                        "weka.classifiers.functions.LibSVM").newInstance();
                // Step 5: train every classifier on the training set
                classifier4.buildClassifier(instancesTrain);
                classifier1.buildClassifier(instancesTrain);
                classifier2.buildClassifier(instancesTrain);
                classifier3.buildClassifier(instancesTrain);
                // Step 6: evaluate on the test set and print the error rate.
                // An Evaluation object accumulates statistics across calls,
                // so each classifier gets a fresh one.
                Evaluation eval = new Evaluation(instancesTrain);
                eval.evaluateModel(classifier4, instancesTest);
                System.out.println(eval.errorRate());
                eval = new Evaluation(instancesTrain);
                eval.evaluateModel(classifier1, instancesTest);
                System.out.println(eval.errorRate());
                eval = new Evaluation(instancesTrain);
                eval.evaluateModel(classifier2, instancesTest);
                System.out.println(eval.errorRate());
                eval = new Evaluation(instancesTrain);
                eval.evaluateModel(classifier3, instancesTest);
                System.out.println(eval.errorRate());
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
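For a nominal class, the number printed by errorRate() is the share of evaluated instances that were misclassified. Evaluation keeps running totals, so a value read after several evaluateModel calls on the same object covers all of those calls together. The sketch below imitates that bookkeeping in plain Java; ErrorRateSketch and its labels are made up for illustration and are not Weka code.

```java
final class ErrorRateSketch {
    private int total;
    private int wrong;

    /** Adds one batch of predictions to the running totals. */
    void record(String[] actual, String[] predicted) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("arrays must have the same length");
        }
        for (int i = 0; i < actual.length; i++) {
            total++;
            if (!actual[i].equals(predicted[i])) {
                wrong++;
            }
        }
    }

    /** Misclassified instances divided by all instances recorded so far. */
    double errorRate() {
        return total == 0 ? 0.0 : (double) wrong / total;
    }

    public static void main(String[] args) {
        String[] actual = {"yes", "no", "yes", "no"};
        String[] modelA = {"yes", "no", "yes", "yes"}; // 1 of 4 wrong
        String[] modelB = {"no", "yes", "yes", "no"};  // 2 of 4 wrong

        ErrorRateSketch onlyA = new ErrorRateSketch();
        onlyA.record(actual, modelA);
        System.out.println(onlyA.errorRate()); // 0.25

        // Recording both models in one counter blends their results: 3 of 8 wrong.
        ErrorRateSketch shared = new ErrorRateSketch();
        shared.record(actual, modelA);
        shared.record(actual, modelB);
        System.out.println(shared.errorRate()); // 0.375
    }
}
```

The second number belongs to neither model on its own, which is why a separate counter per classifier is the safer pattern when comparing them.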
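If you only have a training set, use 10-fold cross-validation instead: replace steps 5 and 6 above (the buildClassifier and evaluateModel calls) with the following code: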
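    // Requires: import java.util.Random;
    // crossValidateModel trains the classifier itself on each fold.
    // A fresh Evaluation per classifier keeps the error rates separate.
    Evaluation eval = new Evaluation(instancesTrain);
    eval.crossValidateModel(classifier4, instancesTrain, 10, new Random(1));
    System.out.println(eval.errorRate());
    eval = new Evaluation(instancesTrain);
    eval.crossValidateModel(classifier1, instancesTrain, 10, new Random(1));
    System.out.println(eval.errorRate());
    eval = new Evaluation(instancesTrain);
    eval.crossValidateModel(classifier2, instancesTrain, 10, new Random(1));
    System.out.println(eval.errorRate());
    eval = new Evaluation(instancesTrain);
    eval.crossValidateModel(classifier3, instancesTrain, 10, new Random(1));
    System.out.println(eval.errorRate());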
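To make the idea behind k-fold cross-validation concrete, here is a minimal sketch in plain Java of how the rows can be shuffled and dealt into 10 folds, each of which serves once as the test part while the remaining rows are used for training. KFoldSketch and split are hypothetical names, and this is not how Weka implements it internally; in particular, Weka also stratifies the folds when the class is nominal, which this sketch does not attempt.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

final class KFoldSketch {

    /** Shuffles the row indices 0..n-1 and deals them into k folds of near-equal size. */
    static List<List<Integer>> split(int n, int k, Random rng) {
        if (k < 2 || k > n) {
            throw new IllegalArgumentException("need 2 <= k <= n");
        }
        List<Integer> indices = new ArrayList<Integer>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, rng);

        List<List<Integer>> folds = new ArrayList<List<Integer>>();
        for (int f = 0; f < k; f++) {
            folds.add(new ArrayList<Integer>());
        }
        // Round-robin assignment keeps fold sizes within one row of each other.
        for (int i = 0; i < n; i++) {
            folds.get(i % k).add(indices.get(i));
        }
        return folds;
    }

    public static void main(String[] args) {
        int rows = 25;
        List<List<Integer>> folds = split(rows, 10, new Random(1));
        for (int f = 0; f < folds.size(); f++) {
            int testSize = folds.get(f).size();
            System.out.println("fold " + f + ": test on " + testSize
                    + " rows, train on " + (rows - testSize));
        }
    }
}
```

Passing a Random with a fixed seed, as the snippet above does with new Random(1), makes the shuffle repeatable from run to run.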
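To save and load the trained classifier models, add the following code between steps 5 and 6 (after the buildClassifier calls and before the evaluation):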
    // Requires: import weka.core.SerializationHelper;
    // Save each trained model to a file in the working directory
    SerializationHelper.write("LibSVM.model", classifier4);
    SerializationHelper.write("NaiveBayes.model", classifier1);
    SerializationHelper.write("J48.model", classifier2);
    SerializationHelper.write("ZeroR.model", classifier3);
    // Load the models back from those files
    Classifier classifier8 = (Classifier) SerializationHelper.read("LibSVM.model");
    Classifier classifier5 = (Classifier) SerializationHelper.read("NaiveBayes.model");
    Classifier classifier6 = (Classifier) SerializationHelper.read("J48.model");
    Classifier classifier7 = (Classifier) SerializationHelper.read("ZeroR.model");
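SerializationHelper builds on standard Java object serialization, so the saved file is a serialized Java object and reading it back yields an Object that has to be cast. The sketch below shows the same round trip with plain java.io classes. MajorityModel is a hypothetical stand-in for a trained classifier, and the bytes stay in memory instead of going to a .model file so that the example runs without Weka.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

final class ModelRoundTrip {

    /** Hypothetical stand-in for a trained model: it always predicts one label. */
    static final class MajorityModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final String label;

        MajorityModel(String label) {
            this.label = label;
        }

        String predict() {
            return label;
        }
    }

    /** Serializes the model to a byte array. */
    static byte[] write(Serializable model) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(model);
        } catch (IOException e) {
            throw new IllegalStateException("could not serialize model", e);
        }
        return bytes.toByteArray();
    }

    /** Restores an object from bytes produced by write(). */
    static Object read(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("could not deserialize model", e);
        }
    }

    public static void main(String[] args) {
        byte[] saved = write(new MajorityModel("yes"));
        MajorityModel loaded = (MajorityModel) read(saved);
        System.out.println(loaded.predict()); // yes
    }
}
```

As with any Java serialization, loading only works when the model's class is available on the classpath, so a saved Weka model needs weka.jar (and libsvm.jar for the LibSVM model) when it is read back.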