package weka_test;
import java.io.File;
import java.io.IOException;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.experiment.InstanceQuery;
import weka.classifiers.Evaluation;
import java.util.Random;
public class test {
/**
* oracleInput
* @return data
* @throws Exception
*/
public static Instances oracleInput() throws Exception{
InstanceQuery query = new InstanceQuery();
String sql = "SELECT to_char(z.cydate,'yyyy/mm') AS d,sum(z.bcmoney) as c FROM zybc z"
+ " WHERE to_char(z.cydate,'yyyy/mm') IS NOT NULL"
+ " GROUP BY to_char(z.cydate,'yyyy/mm') ORDER BY to_date(to_char(z.cydate,'yyyy/mm'),'yyyy/mm') ASC";
//System.out.println(sql);
query.setCustomPropsFile(new File("weka/weka_oracle.props"));
query.setDatabaseURL("jdbc:oracle:thin:@192.168.2.133:1521/XE");
query.setUsername("***");
query.setPassword("***");
query.setQuery(sql);
Instances data = query.retrieveInstances();
return data;
}
/**
* mysqlInput
* @return data
* @throws Exception
*/
public static Instances mysqlInput() throws Exception{
InstanceQuery query = new InstanceQuery();
String sql = "SELECT * FROM iris";
//System.out.println(sql);
query.setCustomPropsFile(new File("weka/weka_mysql.props"));
query.setDatabaseURL("jdbc:mysql://localhost:3306/test");
query.setUsername("***");
query.setPassword("***");
query.setQuery(sql);
Instances data = query.retrieveInstances();
return data;
}
/**
* @param args
* @throws Exception
*/
public static void main(String[] args) throws Exception {
// TODO Auto-generated method stub
Classifier m_classifier = new J48();
/*File inputFile = new File("D://Program Files//Weka-3-7//data//iris.arff");//训练语料文件
ArffLoader atf = new ArffLoader();
atf.setFile(inputFile);
Instances instancesTrain = atf.getDataSet(); // 读入训练文件 */
Instances Train = mysqlInput();
Instances Test = mysqlInput();
Test.setClassIndex(4); //设置分类属性所在行号(第一行为0号),instancesTest.numAttributes()可以取得属性总数
double sum = Test.numInstances(),//测试语料实例数
right = 0.0f;
Train.setClassIndex(4);
m_classifier.buildClassifier(Train); //训练
//System.out.println(m_classifier.toString());
//2、利用模型进行预测
int a=0,b=0,c=0,d=0;//记录每个类别的个数,方便计算评价指标
for (int i = 0; i < Test.numInstances(); i++) {
double classification = m_classifier.classifyInstance(Train.instance(i));
double classValue = Train.instance(i).classValue();
if (classification == 0.0 && classValue == 0.0) {
a++;
} else if (classification == 0.0 && classValue == 1.0) {
b++;
} else if (classification == 1.0 && classValue == 0.0) {
c++;
} else if (classification == 1.0 && classValue == 1.0) {
d++;
}
}
// 3、得出预测效果评测指标
double precision = (double) a / (a + b);
double recall = (double) a / (a + c);
double fMeasure = 2 * precision * recall / (precision + recall);
System.out.println("precision\trecall\tF-Measure");
System.out.println((precision) + "\t\t"
+ (recall) + "\t"
+ (fMeasure));
for(int i = 0;i
{
if(m_classifier.classifyInstance(Test.instance(i))==Test.instance(i).classValue())//如果预测值和答案值相等(测试语料中的分类列提供的须为正确答案,结果才有意义)
{
right++;//正确值加1
}
}
String s=right+","+sum+",";
System.out.println("classification precision:"+s+(right/sum));
//直接调用Evaluation即可完成
Evaluation eval = null;
for (int i = 0; i < 10; i++) {
eval = new Evaluation(Train);
eval.crossValidateModel(m_classifier, Train, 10, new Random(i),
args);// 实现交叉验证模型
}
System.out.println(eval.toSummaryString());// 输出总结信息
System.out.println(eval.toClassDetailsString());// 输出分类详细信息
System.out.println(eval.toMatrixString());// 输出分类的混淆矩阵
}
}