通常,在网上找到的mahout的naive bayes的例子跟官网的例子,都是针对20 newsgroup. 而且通常是命令行版本。虽然能得出预测、分类结果,但是对于Bayes具体是如何工作,以及如何处理自己的数据会比较茫然。
在努力了差不多一个星期之后,终于有点成果。
这个例子就是使用mahout 0.9 对kddcup 1999 的数据进行分析。
第一步: 下载数据。
地址: http://kdd.ics.uci.edu/databases/kddcup99/
关于数据的一些简单的预处理,我们会在第二步进行。细心的你可能发现,有些数据是2007年上传的!这是因为有一些数据原来的标记有错误,后来进行了更正。
第二步: 将原始文件转换成Hadoop使用的sequence 文件。
我们从官网知道,Bayes在mahout之中只有基于map-reduce的实现。 参考: https://mahout.apache.org/users/basics/algorithms.html 所以我们必须要将csv文件转换成hadoop使用的sequence文件
先贴一下代码:(注意:这里列的代码,仅仅用于说明流程,并没有注意性能方面的考虑。处理过大的文件的时候,需要有针对性的自行进行调整~)
package experiment.kdd99_bayes;
import java.io.FileReader;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import au.com.bytecode.opencsv.CSVReader;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
public class Kdd99CsvToSeqFile {
private String csvPath;
private Path seqPath;
private SequenceFile.Writer writer;
private Configuration conf = new Configuration();
private Map<String, Long> word2LongMap = Maps.newHashMap();
private List<String> strLabelList = Lists.newArrayList();
private FileSystem fs = null;
public Kdd99CsvToSeqFile(String csvFilePath, String seqPath) {
this.csvPath = csvFilePath;
this.seqPath = new Path(seqPath);
}
public Map<String, Long> getWordMap() {
return word2LongMap;
}
public List<String> getLabelList() {
return strLabelList;
}
/**
* Show out the already sequenced file content
*/
public void dump() {
try {
fs = FileSystem.get(conf);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, this.seqPath, conf);
Text key = new Text();
VectorWritable value = new VectorWritable();
while (reader.next(key, value)) {
System.out.println( "reading key:" + key.toString() +" with value " +
value.toString());
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
fs.close();
fs = null;
} catch (IOException e) {
e.printStackTrace();
}
}
}
/**
* Sequence target csv file.
* @param labelIndex
* @param hasHeader
*/
public void parse(int labelIndex, boolean hasHeader) {
CSVReader reader = null;
try {
fs = FileSystem.getLocal(conf);
if(fs.exists(this.seqPath))
fs.delete(this.seqPath, true);
writer = SequenceFile.createWriter(fs, conf, this.seqPath, Text.class, VectorWritable.class);
reader = new CSVReader(new FileReader(this.csvPath));
String[] header = null;
if(hasHeader) header = reader.readNext();
String[] line = null;
Long l = 0L;
while((line = reader.readNext()) != null) {
if(labelIndex > line.length) break;
l++;
List<String> tmpList = Lists.newArrayList(line);
String label = tmpList.get(labelIndex);
if(!strLabelList.contains(label)) strLabelList.add(label);
// Text key = new Text("/" + label + "/" + l);
Text key = new Text("/" + label + "/");
tmpList.remove(labelIndex);
VectorWritable vectorWritable = new VectorWritable();
Vector vector = new RandomAccessSparseVector(tmpList.size(), tmpList.size());//???
for(int i = 0; i < tmpList.size(); i++) {
String tmpStr = tmpList.get(i);
if(StringUtils.isNumeric(tmpStr))
vector.set(i, Double.parseDouble(tmpStr));
else
vector.set(i, parseStrCell(tmpStr));
}
vectorWritable.set(vector);
writer.append(key, vectorWritable);
}
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
fs.close();
fs = null;
writer.close();
reader.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
private Long parseStrCell(String str) {
Long id = word2LongMap.get(str);
if( id == null) {
id = (long) (word2LongMap.size() + 1);
word2LongMap.put(str, id);
}
return id;
}
}
说明一下这个代码的工作流程:
1. 初始化hadoop,比如Configuration 、 FileSystem。
2. 通过Hadoop的 Sequence.Writer进行sequence文件的写入。其中的key/value 分别是Text 跟VectorWritable类型。
3. 通过CSVReader读入CSV文件,然后逐行遍历。如果是带标题的,则先略过第一行。
4. 对于每一行,将Array转成List方便操作。将label列从list之中删除~
5. 对于sequencefile, key为label + row number, 并且,需要以"/"作为开头,否则在实际运行的时候会提示找不到key!
6. 对于sequencefile的value,使用一个Vector进行数据承载。在此使用的是RandomAccessSparseVector,可以试着使用DenseVector进行测试,看看是否在性能上会有所改善。
在用Bayes试过了好几种数据之后,感觉对于Bayes,最关键的一步其实是在这里,因为选择那些feature、原始数据如何预处理就在这里进行了,剩下的都是模板一样的代码~ 即使命令行也一样。
第三步: 训练Bayes
在这里仅仅先贴出训练部分的代码,整体的代码最后上传
public static void train() throws Throwable {
System.out.println("~~~ begin to train ~~~");
Configuration conf = new Configuration();
FileSystem fs = FileSystem.getLocal(conf);
TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
trainNaiveBayes.setConf(conf);
String outputDirectory = "/home/hadoop/DataSet/kdd99/bayes/output";
String tempDirectory = "/home/hadoop/DataSet/kdd99/bayes/temp";
fs.delete(new Path(outputDirectory),true);
fs.delete(new Path(tempDirectory),true);
// cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c
trainNaiveBayes.run(new String[] {
"--input", trainSeqFile,
"--output", outputDirectory,
"-el",
"--labelIndex", "labelIndex",
"--overwrite",
"--tempDir", tempDirectory });
// Train the classifier
naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDirectory), conf);
System.out.println("features: " + naiveBayesModel.numFeatures());
System.out.println("labels: " + naiveBayesModel.numLabels());
}
从上面的代码可以看到,熟悉命令行之后,在实际java代码编写的时候,传入进去的也是一些命令行参数。
(可能有其他方法,只是目前我还不了解~)
命令行:
// cmd sample: mahout trainnb -i train-vectors -el -li labelindex -o model -ow -c
Java代码:
trainNaiveBayes.run
最后一步: 使用测试数据进行性能验证。
public static void test() throws IOException {
System.out.println("~~~ begin to test ~~~");
AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);
CSVReader csv = new CSVReader(new FileReader(testFile));
csv.readNext(); // skip header
String[] line = null;
double totalSampleCount = 0.;
double correctClsCount = 0.;
while((line = csv.readNext()) != null) {
totalSampleCount ++;
Vector vector = new RandomAccessSparseVector(40,40);//???
for(int i = 0; i < 40; i++) {
if(StringUtils.isNumeric(line[i])) {
vector.set(i, Double.parseDouble(line[i]));
} else {
Long id = strOptionMap.get(line[i]);
if(id != null)
vector.set(i, id);
else {
System.out.println(StringUtils.join(line, ","));
continue;
}
}
}
Vector resultVector = classifier.classifyFull(vector);
int classifyResult = resultVector.maxValueIndex();
if(StringUtils.equals(line[41], strLabelList.get(classifyResult))) {
correctClsCount++;
} else {
System.out.println("Correct=" + line[41] + "\tClassify=" + strLabelList.get(classifyResult) );
}
}
System.out.println("Correct Ratio:" + (correctClsCount / totalSampleCount)); }
可以看到上面的加粗部分,用的是ComplementaryNaiveBayesClassifier,另外一个贝叶斯分类器就是
StandardNaiveBayesClassifier
最后运算的结果不太好,仅有约63%的正确率~
大家可以参考下面使用Bayes对Tweet进行分类的例子,正确率能有98%这样!当然,需要各位有过功夫网的本领了~
PS: 全部java代码已经在附件之中,感兴趣的还请自取~