我觉得首先有必要简单说说交叉验证,即用只有一个训练集的时候,用一部分数据训练,一部分做测试,当然怎么分配及时不同的方法了。
1)k-folder cross-validation:
k个子集,每个子集均做一次测试集,其余的作为训练集。交叉验证重复k次,每次选择一个子集作为测试集,并将k次的平均交叉验证识别正确率作为结果。优点:所有的样本都被作为了训练集和测试集,每个样本都被验证一次。10-folder通常被使用。
2)K * 2 folder cross-validation
是k-folder cross-validation的一个变体,对每一个folder,都平均分成两个集合s0,s1,我们先在集合s0训练用s1测试,然后用s1训练s0测试。优点是:测试和训练集都足够大,每一个个样本都被作为训练集和测试集。一般使用k=10
3)least-one-out cross-validation(loocv)
假设dataset中有n个样本,那LOOCV也就是n-CV,意思是每个样本单独作为一次测试集,剩余n-1个样本则做为训练集。优点:
1)每一回合中几乎所有的样本皆用于训练model,因此最接近母体样本的分布,估测所得的generalization error比较可靠。
2)实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。
但LOOCV的缺点则是计算成本高,为需要建立的models数量与总样本数量相同,当总样本数量相当多时,LOOCV在实作上便有困难,除非每次训练model的速度很快,或是可以用平行化计算减少计算所需的时间。
关键代码:
//直接调用Evaluation即可完成 Evaluation eval = null; for (int i = 0; i < 10; i++) { eval = new Evaluation(Train); eval.crossValidateModel(m_classifier, Train, 10, new Random(i), args);// 实现交叉验证模型 } System.out.println(eval.toSummaryString());// 输出总结信息 System.out.println(eval.toClassDetailsString());// 输出分类详细信息 System.out.println(eval.toMatrixString());// 输出分类的混淆矩阵
Java调用weka实现算法,并保存模型,以及读取。
这个在网上找了很久,没找到,却偶然一次发现了,其实很简单,只要因为好一点的话,看国外论坛就好多了。
保存模型方法:
保存模型方法:
加载模型:SerializationHelper.write("LibSVM.model", classifier4);//参数一为模型保存文件,classifier4为要保存的模型
Classifier classifier8 = (Classifier) weka.core.SerializationHelper.read("LibSVM.model");
全部代码:
package weka_test;
import java.io.File;
import java.io.IOException;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.experiment.InstanceQuery;
import weka.classifiers.Evaluation;
import java.util.Random;
public class test {
/**
* oracleInput
* @return data
* @throws Exception
*/
public static Instances oracleInput() throws Exception{
InstanceQuery query = new InstanceQuery();
String sql = "SELECT to_char(z.cydate,'yyyy/mm') AS d,sum(z.bcmoney) as c FROM zybc z"
+ " WHERE to_char(z.cydate,'yyyy/mm') IS NOT NULL"
+ " GROUP BY to_char(z.cydate,'yyyy/mm') ORDER BY to_date(to_char(z.cydate,'yyyy/mm'),'yyyy/mm') ASC";
//System.out.println(sql);
query.setCustomPropsFile(new File("weka/weka_oracle.props"));
query.setDatabaseURL("jdbc:oracle:thin:@192.168.2.133:1521/XE");
query.setUsername("***");
query.setPassword("***");
query.setQuery(sql);
Instances data = query.retrieveInstances();
return data;
}
/**
* mysqlInput
* @return data
* @throws Exception
*/
public static Instances mysqlInput() throws Exception{
InstanceQuery query = new InstanceQuery();
String sql = "SELECT * FROM iris";
//System.out.println(sql);
query.setCustomPropsFile(new File("weka/weka_mysql.props"));
query.setDatabaseURL("jdbc:mysql://localhost:3306/test");
query.setUsername("***");
query.setPassword("***");
query.setQuery(sql);
Instances data = query.retrieveInstances();
return data;
}
/**
* @param args
* @throws Exception
*/
public static void main(String[] args) throws Exception {
// TODO Auto-generated method stub
Classifier m_classifier = new J48();
/*File inputFile = new File("D://Program Files//Weka-3-7//data//iris.arff");//训练语料文件
ArffLoader atf = new ArffLoader();
atf.setFile(inputFile);
Instances instancesTrain = atf.getDataSet(); // 读入训练文件 */
Instances Train = mysqlInput();
Instances Test = mysqlInput();
Test.setClassIndex(4); //设置分类属性所在行号(第一行为0号),instancesTest.numAttributes()可以取得属性总数
double sum = Test.numInstances(),//测试语料实例数
right = 0.0f;
Train.setClassIndex(4);
m_classifier.buildClassifier(Train); //训练
//System.out.println(m_classifier.toString());
//2、利用模型进行预测
int a=0,b=0,c=0,d=0;//记录每个类别的个数,方便计算评价指标
for (int i = 0; i < Test.numInstances(); i++) {
double classification = m_classifier.classifyInstance(Train.instance(i));
double classValue = Train.instance(i).classValue();
if (classification == 0.0 && classValue == 0.0) {
a++;
} else if (classification == 0.0 && classValue == 1.0) {
b++;
} else if (classification == 1.0 && classValue == 0.0) {
c++;
} else if (classification == 1.0 && classValue == 1.0) {
d++;
}
}
// 3、得出预测效果评测指标
double precision = (double) a / (a + b);
double recall = (double) a / (a + c);
double fMeasure = 2 * precision * recall / (precision + recall);
System.out.println("precision\trecall\tF-Measure");
System.out.println((precision) + "\t\t"
+ (recall) + "\t"
+ (fMeasure));
for(int i = 0;i<sum;i++)//测试分类结果
{
if(m_classifier.classifyInstance(Test.instance(i))==Test.instance(i).classValue())//如果预测值和答案值相等(测试语料中的分类列提供的须为正确答案,结果才有意义)
{
right++;//正确值加1
}
}
String s=right+","+sum+",";
System.out.println("classification precision:"+s+(right/sum));
//直接调用Evaluation即可完成
Evaluation eval = null;
for (int i = 0; i < 10; i++) {
eval = new Evaluation(Train);
eval.crossValidateModel(m_classifier, Train, 10, new Random(i),
args);// 实现交叉验证模型
}
System.out.println(eval.toSummaryString());// 输出总结信息
System.out.println(eval.toClassDetailsString());// 输出分类详细信息
System.out.println(eval.toMatrixString());// 输出分类的混淆矩阵
}
}
python sklearn数据预处理:
广义线性模型--Generalized Linear Models