package org.apache.mahout.classifier.bayes;
public final class TrainClassifier
bayes 和 cbayes 的入口类
两个分支:bayes(trainNaiveBayes)和 cbayes(trainCNaiveBayes)
先设定所有默认参数,如果有非默认项再覆盖
由于参数过多,定义一个类封装参数,便于后续传递
package org.apache.mahout.classifier.bayes;
public final class TestClassifier
入口
分并行和非并行两种实现
public final class TrainClassifier
bayes 和 cbayes 的入口类
两个分支:bayes(trainNaiveBayes)和 cbayes(trainCNaiveBayes)
/**
 * Starts a Bayes training run by handing all three arguments to a freshly
 * created {@code BayesDriver}.
 *
 * @param dir       path forwarded as the first argument of {@code runJob}
 * @param outputDir path forwarded as the second argument of {@code runJob}
 * @param params    parameter bundle forwarded unchanged to {@code runJob}
 * @throws IOException declared by this method; presumably raised by
 *                     {@code runJob} (not visible here, confirm)
 */
public static void trainNaiveBayes(Path dir, Path outputDir, BayesParameters params) throws IOException {
  // The driver is used only for this one call, so no local is kept.
  new BayesDriver().runJob(dir, outputDir, params);
}
/**
 * Starts a complementary-Bayes training run by handing all three arguments to
 * a freshly created {@code CBayesDriver}.
 *
 * @param dir       path forwarded as the first argument of {@code runJob}
 * @param outputDir path forwarded as the second argument of {@code runJob}
 * @param params    parameter bundle forwarded unchanged to {@code runJob}
 * @throws IOException declared by this method; presumably raised by
 *                     {@code runJob} (not visible here, confirm)
 */
public static void trainCNaiveBayes(Path dir, Path outputDir, BayesParameters params) throws IOException {
  // The driver is used only for this one call, so no local is kept.
  new CBayesDriver().runJob(dir, outputDir, params);
}
先设定所有默认参数,如果有非默认项再覆盖
由于参数过多,定义一个类封装参数,便于后续传递
// Build the parameter bundle handed to the trainer: defaults are written
// first, then individually overridden by command-line options that are present.
BayesParameters params = new BayesParameters();
// Setting all the default parameter values
params.setGramSize(1);
params.setMinDF(1);
// NOTE(review): "alpha_i" and "dataSource" are stored as plain string
// key/value pairs; their meaning is defined by the consumers of params,
// which are not visible in this excerpt.
params.set("alpha_i","1.0");
params.set("dataSource", "hdfs");
// Override the default gram size only when the option was supplied.
// Integer.parseInt throws NumberFormatException on a non-numeric value.
if (cmdLine.hasOption(gramSizeOpt)) {
params.setGramSize(Integer.parseInt((String) cmdLine.getValue(gramSizeOpt)));
}
// Same pattern for the minimum document frequency option.
if (cmdLine.hasOption(minDfOpt)) {
params.setMinDF(Integer.parseInt((String) cmdLine.getValue(minDfOpt)));
}
// Input and output locations are read unconditionally (no hasOption guard).
// NOTE(review): presumably both options are mandatory in the option
// definitions, which are outside this excerpt — confirm.
Path inputPath = new Path((String) cmdLine.getValue(inputDirOpt));
Path outputPath = new Path((String) cmdLine.getValue(outputOpt));
// Pick the trainer: a case-insensitive match on "cbayes" selects the
// complementary trainer; every other value (including null) takes the else
// branch. classifierType is assigned outside this excerpt.
if ("cbayes".equalsIgnoreCase(classifierType)) {
log.info("Training Complementary Bayes Classifier");
trainCNaiveBayes(inputPath, outputPath, params);
} else {
log.info("Training Bayes Classifier");
// setup the HDFS and copy the files there, then run the trainer
// NOTE(review): no HDFS setup or copy is visible here; only the trainer call
// follows — confirm whether the comment above is still accurate.
trainNaiveBayes(inputPath, outputPath, params);
}
package org.apache.mahout.classifier.bayes;
public final class TestClassifier
入口
/**
 * Runs classification by delegating to the static
 * {@code BayesClassifierDriver.runJob}. The dispatch code in this file calls
 * this method when the classification method is "mapreduce".
 *
 * @param params parameter bundle forwarded unchanged to {@code runJob}
 * @throws IOException declared by this method; presumably raised by
 *                     {@code runJob} (not visible here, confirm)
 */
public static void classifyParallel(BayesParameters params) throws IOException {
BayesClassifierDriver.runJob(params);
}
分并行和非并行两种实现
// Dispatch on the requested classification method; both comparisons are
// case-insensitive and tolerate a null classificationMethod (the literal is
// the receiver of equalsIgnoreCase).
// NOTE(review): any value other than "sequential" or "mapreduce" matches
// neither branch, so nothing is run and no error is reported here — confirm
// the value is validated elsewhere.
if ("sequential".equalsIgnoreCase(classificationMethod)) {
classifySequential(params);
} else if ("mapreduce".equalsIgnoreCase(classificationMethod)) {
classifyParallel(params);
}