水平有限,随便翻译,能看就行
原址:http://weka.wikispaces.com/Generating+cross-validation+folds+%28Java+approach%29
本文主要讲述如何直接使用 Weka API 生成交叉验证的训练/测试划分(splits)。
使用如下变量:
Instances data = ...; // contains the full dataset we want to create train/test sets from
int seed = ...; // the seed for randomizing the data
int folds = ...; // the number of folds to generate, >=2
随机化数据
首先随机化你的数据
Random rand = new Random(seed); // create seeded number generator
randData = new Instances(data); // create copy of original data
randData.randomize(rand); // randomize data with number generator
如果你的数据的类属性是 nominal(标称)类型,并且你想要使用分层交叉验证,还需要对数据进行分层:
randData.stratify(folds);
产生折叠
单次运行
接下来要做的就是为每一折创建训练集和测试集:
for (int n = 0; n < folds; n++) {
Instances train = randData.trainCV(folds, n);
Instances test = randData.testCV(folds, n);
// further processing, classification, etc.
...
}
注意:上面的代码是 StratifiedRemoveFolds 过滤器所使用的方式;Explorer/Experimenter 则使用下面带随机数生成器的方法来获取训练集:
Instances train = randData.trainCV(folds, n, rand);
上面的例子演示了单次运行的交叉验证。如果你想运行10次10折交叉验证,可以使用下面的循环:
Instances data = ...; // our dataset again, obtained from somewhere
int runs = 10;
for (int i = 0; i < runs; i++) {
seed = i+1; // every run gets a new, but defined seed value
// see: randomize the data
...
// see: generate the folds
...
}
下面给出几个具体的例子:
CrossValidationSingleRun.java
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.Utils;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import java.util.Random;
/**
 * Performs a single run of cross-validation.
 *
 * Command-line parameters:
 * <ul>
 * <li>-t filename - the dataset to use (required)</li>
 * <li>-x int - the number of folds to use, at least 2 (required)</li>
 * <li>-s int - the seed for the random number generator (required)</li>
 * <li>-c int - the 1-based class index, "first" and "last" are accepted as well;
 * "last" is used by default</li>
 * <li>-W classifier - classname and options, enclosed by double quotes;
 * the classifier to cross-validate (required)</li>
 * </ul>
 *
 * Example command-line:
 * <pre>
 * java CrossValidationSingleRun -t anneal.arff -c last -x 10 -s 1 -W "weka.classifiers.trees.J48 -C 0.25"
 * </pre>
 *
 * @author FracPete (fracpete at waikato dot ac dot nz)
 */
public class CrossValidationSingleRun {

  /**
   * Performs the cross-validation. See Javadoc of class for information
   * on command-line parameters.
   *
   * @param args the command-line parameters
   * @throws IllegalArgumentException if a required option is missing or
   *         the number of folds is smaller than 2
   * @throws Exception if something goes wrong
   */
  public static void main(String[] args) throws Exception {
    // parse and validate every option first, so that a bad command line
    // fails with a clear message before the dataset is loaded
    String datasetFile = requireOption("t", args);
    String classifierSpec = requireOption("W", args);
    int seed = Integer.parseInt(requireOption("s", args));
    int folds = Integer.parseInt(requireOption("x", args));
    if (folds < 2) {
      throw new IllegalArgumentException(
          "Number of folds (-x) must be at least 2, got: " + folds);
    }
    String clsIndex = Utils.getOption("c", args);

    // loads data and set class index
    Instances data = DataSource.read(datasetFile);
    applyClassIndex(data, clsIndex);

    // classifier
    Classifier cls = createClassifier(classifierSpec);

    // randomize a copy of the data; the original dataset stays untouched
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
      randData.stratify(folds);
    }

    // perform cross-validation, accumulating all folds in one Evaluation
    Evaluation eval = new Evaluation(randData);
    for (int n = 0; n < folds; n++) {
      Instances train = randData.trainCV(folds, n);
      Instances test = randData.testCV(folds, n);
      // the above code is used by the StratifiedRemoveFolds filter, the
      // code below by the Explorer/Experimenter:
      // Instances train = randData.trainCV(folds, n, rand);

      // build and evaluate a fresh copy of the classifier for this fold
      Classifier clsCopy = Classifier.makeCopy(cls);
      clsCopy.buildClassifier(train);
      eval.evaluateModel(clsCopy, test);
    }

    // output evaluation
    System.out.println();
    System.out.println("=== Setup ===");
    System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()));
    System.out.println("Dataset: " + data.relationName());
    System.out.println("Folds: " + folds);
    System.out.println("Seed: " + seed);
    System.out.println();
    System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
  }

  /**
   * Returns the value of a mandatory command-line option.
   *
   * @param flag the option name without the leading dash
   * @param args the command-line parameters
   * @return the non-empty option value
   * @throws IllegalArgumentException if the option is absent or empty
   * @throws Exception if the option cannot be parsed
   */
  private static String requireOption(String flag, String[] args) throws Exception {
    String value = Utils.getOption(flag, args);
    if (value.length() == 0) {
      throw new IllegalArgumentException("Missing required option: -" + flag);
    }
    return value;
  }

  /**
   * Sets the class attribute of the dataset from the -c option value.
   *
   * @param data the dataset to update
   * @param clsIndex "first", "last", a 1-based index, or empty for "last"
   */
  private static void applyClassIndex(Instances data, String clsIndex) {
    if (clsIndex.length() == 0 || clsIndex.equals("last")) {
      data.setClassIndex(data.numAttributes() - 1);
    } else if (clsIndex.equals("first")) {
      data.setClassIndex(0);
    } else {
      // the command line uses 1-based indices, the API 0-based ones
      data.setClassIndex(Integer.parseInt(clsIndex) - 1);
    }
  }

  /**
   * Instantiates the classifier described by the -W option value.
   *
   * @param classifierSpec the classname followed by the classifier's options
   * @return the configured classifier
   * @throws IllegalArgumentException if the specification contains no classname
   * @throws Exception if the classifier cannot be instantiated
   */
  private static Classifier createClassifier(String classifierSpec) throws Exception {
    String[] classifierOptions = Utils.splitOptions(classifierSpec);
    if (classifierOptions.length == 0) {
      throw new IllegalArgumentException("Option -W must contain a classifier classname");
    }
    String classname = classifierOptions[0];
    // blank out the classname so only the classifier's own options remain
    classifierOptions[0] = "";
    return (Classifier) Utils.forName(Classifier.class, classname, classifierOptions);
  }
}
CrossValidationSingleRunVariant.java
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.Utils;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import java.util.Random;
/**
 * Performs a single run of cross-validation. Outputs the confusion matrix
 * for each single fold when the class attribute is nominal; for any other
 * class type a per-fold summary is printed instead.
 *
 * Command-line parameters:
 * <ul>
 * <li>-t filename - the dataset to use (required)</li>
 * <li>-x int - the number of folds to use, at least 2 (required)</li>
 * <li>-s int - the seed for the random number generator (required)</li>
 * <li>-c int - the 1-based class index, "first" and "last" are accepted as well;
 * "last" is used by default</li>
 * <li>-W classifier - classname and options, enclosed by double quotes;
 * the classifier to cross-validate (required)</li>
 * </ul>
 *
 * Example command-line:
 * <pre>
 * java CrossValidationSingleRunVariant -t anneal.arff -c last -x 10 -s 1 -W "weka.classifiers.trees.J48 -C 0.25"
 * </pre>
 *
 * @author FracPete (fracpete at waikato dot ac dot nz)
 */
public class CrossValidationSingleRunVariant {

  /**
   * Performs the cross-validation. See Javadoc of class for information
   * on command-line parameters.
   *
   * @param args the command-line parameters
   * @throws IllegalArgumentException if a required option is missing or
   *         the number of folds is smaller than 2
   * @throws Exception if something goes wrong
   */
  public static void main(String[] args) throws Exception {
    // parse and validate every option first, so that a bad command line
    // fails with a clear message before the dataset is loaded
    String datasetFile = requireOption("t", args);
    String classifierSpec = requireOption("W", args);
    int seed = Integer.parseInt(requireOption("s", args));
    int folds = Integer.parseInt(requireOption("x", args));
    if (folds < 2) {
      throw new IllegalArgumentException(
          "Number of folds (-x) must be at least 2, got: " + folds);
    }
    String clsIndex = Utils.getOption("c", args);

    // loads data and set class index
    Instances data = DataSource.read(datasetFile);
    applyClassIndex(data, clsIndex);

    // classifier
    Classifier cls = createClassifier(classifierSpec);

    // randomize a copy of the data; the original dataset stays untouched
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    boolean nominalClass = randData.classAttribute().isNominal();
    if (nominalClass) {
      randData.stratify(folds);
    }

    // perform cross-validation
    System.out.println();
    System.out.println("=== Setup ===");
    System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()));
    System.out.println("Dataset: " + data.relationName());
    System.out.println("Folds: " + folds);
    System.out.println("Seed: " + seed);
    System.out.println();

    // evalAll accumulates all folds, eval only covers the current fold
    Evaluation evalAll = new Evaluation(randData);
    for (int n = 0; n < folds; n++) {
      Evaluation eval = new Evaluation(randData);
      Instances train = randData.trainCV(folds, n);
      Instances test = randData.testCV(folds, n);
      // the above code is used by the StratifiedRemoveFolds filter, the
      // code below by the Explorer/Experimenter:
      // Instances train = randData.trainCV(folds, n, rand);

      // build and evaluate a fresh copy of the classifier for this fold
      Classifier clsCopy = Classifier.makeCopy(cls);
      clsCopy.buildClassifier(train);
      eval.evaluateModel(clsCopy, test);
      evalAll.evaluateModel(clsCopy, test);

      // output evaluation of this fold
      System.out.println();
      if (nominalClass) {
        System.out.println(eval.toMatrixString("=== Confusion matrix for fold " + (n+1) + "/" + folds + " ===\n"));
      } else {
        // a confusion matrix only exists for a nominal class, so fall
        // back to the summary statistics of this fold
        System.out.println(eval.toSummaryString("=== Summary for fold " + (n+1) + "/" + folds + " ===", false));
      }
    }

    // output evaluation over all folds
    System.out.println();
    System.out.println(evalAll.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));
  }

  /**
   * Returns the value of a mandatory command-line option.
   *
   * @param flag the option name without the leading dash
   * @param args the command-line parameters
   * @return the non-empty option value
   * @throws IllegalArgumentException if the option is absent or empty
   * @throws Exception if the option cannot be parsed
   */
  private static String requireOption(String flag, String[] args) throws Exception {
    String value = Utils.getOption(flag, args);
    if (value.length() == 0) {
      throw new IllegalArgumentException("Missing required option: -" + flag);
    }
    return value;
  }

  /**
   * Sets the class attribute of the dataset from the -c option value.
   *
   * @param data the dataset to update
   * @param clsIndex "first", "last", a 1-based index, or empty for "last"
   */
  private static void applyClassIndex(Instances data, String clsIndex) {
    if (clsIndex.length() == 0 || clsIndex.equals("last")) {
      data.setClassIndex(data.numAttributes() - 1);
    } else if (clsIndex.equals("first")) {
      data.setClassIndex(0);
    } else {
      // the command line uses 1-based indices, the API 0-based ones
      data.setClassIndex(Integer.parseInt(clsIndex) - 1);
    }
  }

  /**
   * Instantiates the classifier described by the -W option value.
   *
   * @param classifierSpec the classname followed by the classifier's options
   * @return the configured classifier
   * @throws IllegalArgumentException if the specification contains no classname
   * @throws Exception if the classifier cannot be instantiated
   */
  private static Classifier createClassifier(String classifierSpec) throws Exception {
    String[] classifierOptions = Utils.splitOptions(classifierSpec);
    if (classifierOptions.length == 0) {
      throw new IllegalArgumentException("Option -W must contain a classifier classname");
    }
    String classname = classifierOptions[0];
    // blank out the classname so only the classifier's own options remain
    classifierOptions[0] = "";
    return (Classifier) Utils.forName(Classifier.class, classname, classifierOptions);
  }
}
CrossValidationMultipleRuns.java
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.Utils;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import java.util.Random;
/**
 * Performs multiple runs of cross-validation. Run number i (starting at 1)
 * uses i as the seed for randomizing the data.
 *
 * Command-line parameters:
 * <ul>
 * <li>-t filename - the dataset to use (required)</li>
 * <li>-x int - the number of folds to use, at least 2 (required)</li>
 * <li>-r int - the number of runs to perform (required)</li>
 * <li>-c int - the 1-based class index, "first" and "last" are accepted as well;
 * "last" is used by default</li>
 * <li>-W classifier - classname and options, enclosed by double quotes;
 * the classifier to cross-validate (required)</li>
 * </ul>
 *
 * Example command-line:
 * <pre>
 * java CrossValidationMultipleRuns -t labor.arff -c last -x 10 -r 10 -W "weka.classifiers.trees.J48 -C 0.25"
 * </pre>
 *
 * @author FracPete (fracpete at waikato dot ac dot nz)
 */
public class CrossValidationMultipleRuns {

  /**
   * Performs the cross-validation. See Javadoc of class for information
   * on command-line parameters.
   *
   * @param args the command-line parameters
   * @throws IllegalArgumentException if a required option is missing or
   *         the number of folds is smaller than 2
   * @throws Exception if something goes wrong
   */
  public static void main(String[] args) throws Exception {
    // parse and validate every option first, so that a bad command line
    // fails with a clear message before the dataset is loaded
    String datasetFile = requireOption("t", args);
    String classifierSpec = requireOption("W", args);
    int runs = Integer.parseInt(requireOption("r", args));
    int folds = Integer.parseInt(requireOption("x", args));
    if (folds < 2) {
      throw new IllegalArgumentException(
          "Number of folds (-x) must be at least 2, got: " + folds);
    }
    String clsIndex = Utils.getOption("c", args);

    // loads data and set class index
    Instances data = DataSource.read(datasetFile);
    applyClassIndex(data, clsIndex);

    // classifier
    Classifier cls = createClassifier(classifierSpec);

    // perform cross-validation
    for (int i = 0; i < runs; i++) {
      // randomize a fresh copy of the data; every run gets a new, but
      // defined seed value
      int seed = i + 1;
      Random rand = new Random(seed);
      Instances randData = new Instances(data);
      randData.randomize(rand);
      if (randData.classAttribute().isNominal()) {
        randData.stratify(folds);
      }

      Evaluation eval = new Evaluation(randData);
      for (int n = 0; n < folds; n++) {
        Instances train = randData.trainCV(folds, n);
        Instances test = randData.testCV(folds, n);
        // the above code is used by the StratifiedRemoveFolds filter, the
        // code below by the Explorer/Experimenter:
        // Instances train = randData.trainCV(folds, n, rand);

        // build and evaluate a fresh copy of the classifier for this fold
        Classifier clsCopy = Classifier.makeCopy(cls);
        clsCopy.buildClassifier(train);
        eval.evaluateModel(clsCopy, test);
      }

      // output evaluation
      System.out.println();
      System.out.println("=== Setup run " + (i+1) + " ===");
      System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()));
      System.out.println("Dataset: " + data.relationName());
      System.out.println("Folds: " + folds);
      System.out.println("Seed: " + seed);
      System.out.println();
      // note the space before the closing "===", matching the setup header
      System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i+1) + " ===", false));
    }
  }

  /**
   * Returns the value of a mandatory command-line option.
   *
   * @param flag the option name without the leading dash
   * @param args the command-line parameters
   * @return the non-empty option value
   * @throws IllegalArgumentException if the option is absent or empty
   * @throws Exception if the option cannot be parsed
   */
  private static String requireOption(String flag, String[] args) throws Exception {
    String value = Utils.getOption(flag, args);
    if (value.length() == 0) {
      throw new IllegalArgumentException("Missing required option: -" + flag);
    }
    return value;
  }

  /**
   * Sets the class attribute of the dataset from the -c option value.
   *
   * @param data the dataset to update
   * @param clsIndex "first", "last", a 1-based index, or empty for "last"
   */
  private static void applyClassIndex(Instances data, String clsIndex) {
    if (clsIndex.length() == 0 || clsIndex.equals("last")) {
      data.setClassIndex(data.numAttributes() - 1);
    } else if (clsIndex.equals("first")) {
      data.setClassIndex(0);
    } else {
      // the command line uses 1-based indices, the API 0-based ones
      data.setClassIndex(Integer.parseInt(clsIndex) - 1);
    }
  }

  /**
   * Instantiates the classifier described by the -W option value.
   *
   * @param classifierSpec the classname followed by the classifier's options
   * @return the configured classifier
   * @throws IllegalArgumentException if the specification contains no classname
   * @throws Exception if the classifier cannot be instantiated
   */
  private static Classifier createClassifier(String classifierSpec) throws Exception {
    String[] classifierOptions = Utils.splitOptions(classifierSpec);
    if (classifierOptions.length == 0) {
      throw new IllegalArgumentException("Option -W must contain a classifier classname");
    }
    String classname = classifierOptions[0];
    // blank out the classname so only the classifier's own options remain
    classifierOptions[0] = "";
    return (Classifier) Utils.forName(Classifier.class, classname, classifierOptions);
  }
}
CrossValidationAddPrediction.java
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.converters.ConverterUtils.DataSink;
import weka.core.Utils;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AddClassification;
import java.util.Random;
/**
 * Performs a single run of cross-validation and adds the prediction on the
 * test set to the dataset.
 *
 * Command-line parameters:
 * <ul>
 * <li>-t filename - the dataset to use (required)</li>
 * <li>-o filename - the output file to store dataset with the predictions
 * in (required)</li>
 * <li>-x int - the number of folds to use, at least 2 (required)</li>
 * <li>-s int - the seed for the random number generator (required)</li>
 * <li>-c int - the 1-based class index, "first" and "last" are accepted as well;
 * "last" is used by default</li>
 * <li>-W classifier - classname and options, enclosed by double quotes;
 * the classifier to cross-validate (required)</li>
 * </ul>
 *
 * Example command-line:
 * <pre>
 * java CrossValidationAddPrediction -t anneal.arff -c last -o predictions.arff -x 10 -s 1 -W "weka.classifiers.trees.J48 -C 0.25"
 * </pre>
 *
 * @author FracPete (fracpete at waikato dot ac dot nz)
 */
public class CrossValidationAddPrediction {

  /**
   * Performs the cross-validation. See Javadoc of class for information
   * on command-line parameters.
   *
   * @param args the command-line parameters
   * @throws IllegalArgumentException if a required option is missing or
   *         the number of folds is smaller than 2
   * @throws Exception if something goes wrong
   */
  public static void main(String[] args) throws Exception {
    // parse and validate every option first (including the output file),
    // so that a bad command line fails before the cross-validation is run
    String datasetFile = requireOption("t", args);
    String outputFile = requireOption("o", args);
    String classifierSpec = requireOption("W", args);
    int seed = Integer.parseInt(requireOption("s", args));
    int folds = Integer.parseInt(requireOption("x", args));
    if (folds < 2) {
      // also guarantees the fold loop runs, so predictedData is never null
      throw new IllegalArgumentException(
          "Number of folds (-x) must be at least 2, got: " + folds);
    }
    String clsIndex = Utils.getOption("c", args);

    // loads data and set class index
    Instances data = DataSource.read(datasetFile);
    applyClassIndex(data, clsIndex);

    // classifier
    Classifier cls = createClassifier(classifierSpec);

    // randomize a copy of the data; the original dataset stays untouched
    Random rand = new Random(seed);
    Instances randData = new Instances(data);
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
      randData.stratify(folds);
    }

    // perform cross-validation and add predictions
    Instances predictedData = null;
    Evaluation eval = new Evaluation(randData);
    for (int n = 0; n < folds; n++) {
      Instances train = randData.trainCV(folds, n);
      Instances test = randData.testCV(folds, n);
      // the above code is used by the StratifiedRemoveFolds filter, the
      // code below by the Explorer/Experimenter:
      // Instances train = randData.trainCV(folds, n, rand);

      // build and evaluate a fresh copy of the classifier for this fold
      Classifier clsCopy = Classifier.makeCopy(cls);
      clsCopy.buildClassifier(train);
      eval.evaluateModel(clsCopy, test);

      // add predictions; the filter is given the classifier setup and is
      // trained on the training fold separately from clsCopy above
      AddClassification filter = new AddClassification();
      filter.setClassifier(cls);
      filter.setOutputClassification(true);
      filter.setOutputDistribution(true);
      filter.setOutputErrorFlag(true);
      filter.setInputFormat(train);
      Filter.useFilter(train, filter); // trains the classifier
      Instances pred = Filter.useFilter(test, filter); // perform predictions on test set

      // the first fold defines the header of the collected predictions
      if (predictedData == null) {
        predictedData = new Instances(pred, 0);
      }
      for (int j = 0; j < pred.numInstances(); j++) {
        predictedData.add(pred.instance(j));
      }
    }

    // output evaluation
    System.out.println();
    System.out.println("=== Setup ===");
    System.out.println("Classifier: " + cls.getClass().getName() + " " + Utils.joinOptions(cls.getOptions()));
    System.out.println("Dataset: " + data.relationName());
    System.out.println("Folds: " + folds);
    System.out.println("Seed: " + seed);
    System.out.println();
    System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===", false));

    // output "enriched" dataset
    DataSink.write(outputFile, predictedData);
  }

  /**
   * Returns the value of a mandatory command-line option.
   *
   * @param flag the option name without the leading dash
   * @param args the command-line parameters
   * @return the non-empty option value
   * @throws IllegalArgumentException if the option is absent or empty
   * @throws Exception if the option cannot be parsed
   */
  private static String requireOption(String flag, String[] args) throws Exception {
    String value = Utils.getOption(flag, args);
    if (value.length() == 0) {
      throw new IllegalArgumentException("Missing required option: -" + flag);
    }
    return value;
  }

  /**
   * Sets the class attribute of the dataset from the -c option value.
   *
   * @param data the dataset to update
   * @param clsIndex "first", "last", a 1-based index, or empty for "last"
   */
  private static void applyClassIndex(Instances data, String clsIndex) {
    if (clsIndex.length() == 0 || clsIndex.equals("last")) {
      data.setClassIndex(data.numAttributes() - 1);
    } else if (clsIndex.equals("first")) {
      data.setClassIndex(0);
    } else {
      // the command line uses 1-based indices, the API 0-based ones
      data.setClassIndex(Integer.parseInt(clsIndex) - 1);
    }
  }

  /**
   * Instantiates the classifier described by the -W option value.
   *
   * @param classifierSpec the classname followed by the classifier's options
   * @return the configured classifier
   * @throws IllegalArgumentException if the specification contains no classname
   * @throws Exception if the classifier cannot be instantiated
   */
  private static Classifier createClassifier(String classifierSpec) throws Exception {
    String[] classifierOptions = Utils.splitOptions(classifierSpec);
    if (classifierOptions.length == 0) {
      throw new IllegalArgumentException("Option -W must contain a classifier classname");
    }
    String classname = classifierOptions[0];
    // blank out the classname so only the classifier's own options remain
    classifierOptions[0] = "";
    return (Classifier) Utils.forName(Classifier.class, classname, classifierOptions);
  }
}