The code I had written before all ran on a single machine. Hadoop is very popular these days, so I tried building the decision tree with Hadoop MapReduce as well. I have not worked with Hadoop much, so please forgive the rough code.
I looked at how Mahout handles decision trees and random forests. Roughly, the job uses a single Mapper: the map method converts and collects the data, the cleanup method then builds the decision tree, and the tree is serialized to HDFS; when classifying a sample set, the tree structure is read back from HDFS. In other words, Mahout's tree construction does not really seem to involve distributed computation, though I have not studied the Mahout source carefully, so I may simply have missed it. Below is a simple Hadoop version of a decision tree that I implemented with the C4.5 algorithm, using MapReduce to compute the gain ratios. The generated tree is not written back to HDFS for now; I may look into that later. The concrete implementation follows, after a brief recap of the quantity being computed.
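The MapReduce job scores each attribute by its C4.5 gain ratio, that is, information gain divided by split information. Below is a minimal single-machine sketch of that arithmetic; the class and method names are illustrative only and do not exist in the project:
public class GainRatioSketch {
    private static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }
    // Entropy of a class-count vector over "total" instances.
    private static double entropy(int[] classCounts, double total) {
        double h = 0.0;
        for (int c : classCounts) {
            if (c == 0) continue;
            double p = c / total;
            h -= p * log2(p);
        }
        return h;
    }
    // valueClassCounts[j][k] = number of instances with the j-th attribute value
    // that belong to the k-th class.
    public static double gainRatio(int[][] valueClassCounts) {
        int numClasses = valueClassCounts[0].length;
        int[] totalClassCounts = new int[numClasses];
        double total = 0.0;
        for (int[] counts : valueClassCounts) {
            for (int k = 0; k < numClasses; k++) {
                totalClassCounts[k] += counts[k];
                total += counts[k];
            }
        }
        double infoBefore = entropy(totalClassCounts, total); // entropy of the whole set
        double infoAfter = 0.0;  // weighted entropy of the subsets D_j
        double splitInfo = 0.0;  // entropy of the split itself
        for (int[] counts : valueClassCounts) {
            double dj = 0.0;
            for (int c : counts) dj += c;
            if (dj == 0) continue;
            infoAfter += (dj / total) * entropy(counts, dj);
            splitInfo -= (dj / total) * log2(dj / total);
        }
        double gain = infoBefore - infoAfter;
        return splitInfo == 0.0 ? 0.0 : gain / splitInfo;
    }
    public static void main(String[] args) {
        // Hypothetical attribute with two values:
        // value A -> 2 "yes", 2 "no"; value B -> 4 "yes", 0 "no".
        int[][] counts = { {2, 2}, {4, 0} };
        System.out.println(gainRatio(counts)); // ~0.311
    }
}
With the hypothetical counts in main (8 instances, one attribute value holding 2 "yes" and 2 "no", the other holding 4 "yes"), the gain ratio works out to about 0.311. With that in mind, here is the job implementation: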
public class DecisionTreeC45Job extends AbstractJob {
/** Prepare the data set: upload the data set (with default values already filled in) back to HDFS. */
public String prepare(Data trainData) {
String path = FileUtils.obtainRandomTxtPath();
DataHandler.writeData(path, trainData);
System.out.println(path);
String name = path.substring(path.lastIndexOf(File.separator) + 1);
String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;
HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);
return hdfsPath;
}
/** Choose the best attribute: read the files produced by the MapReduce job and take the largest gain ratio. */
public AttributeGainWritable chooseBestAttribute(String output) {
AttributeGainWritable maxAttribute = null;
Path path = new Path(output);
try {
FileSystem fs = path.getFileSystem(conf);
Path[] paths = HDFSUtils.getPathFiles(fs, path);
ShowUtils.print(paths);
double maxGainRatio = 0.0;
SequenceFile.Reader reader = null;
for (Path p : paths) {
reader = new SequenceFile.Reader(fs, p, conf);
Text key = (Text) ReflectionUtils.newInstance(
reader.getKeyClass(), conf);
AttributeGainWritable value = new AttributeGainWritable();
while (reader.next(key, value)) {
double gainRatio = value.getGainRatio();
if (gainRatio >= maxGainRatio) {
maxGainRatio = gainRatio;
maxAttribute = value;
}
value = new AttributeGainWritable();
}
IOUtils.closeQuietly(reader);
}
System.out.println("output: " + path.toString());
HDFSUtils.delete(conf, path);
System.out.println("hdfs delete file : " + path.toString());
} catch (IOException e) {
e.printStackTrace();
}
return maxAttribute;
}
/** Build the decision tree. */
public Object build(String input, Data data) {
Object preHandleResult = preHandle(data);
if (null != preHandleResult) return preHandleResult;
String output = HDFSUtils.HDFS_TEMP_OUTPUT_URL;
HDFSUtils.delete(conf, new Path(output));
System.out.println("delete output path : " + output);
String[] paths = new String[]{input, output};
// Compute the gain ratios with MapReduce.
CalculateC45GainRatioMR.main(paths);
AttributeGainWritable bestAttr = chooseBestAttribute(output);
String attribute = bestAttr.getAttribute();
System.out.println("best attribute: " + attribute);
System.out.println("isCategory: " + bestAttr.isCategory());
if (bestAttr.isCategory()) {
return attribute;
}
String[] splitPoints = bestAttr.obtainSplitPoints();
System.out.print("splitPoints: ");
ShowUtils.print(splitPoints);
TreeNode treeNode = new TreeNode(attribute);
String[] attributes = data.getAttributesExcept(attribute);
// Split the data set and upload each split back to HDFS.
DataSplit dataSplit = DataHandler.split(new Data(
data.getInstances(), attribute, splitPoints));
for (DataSplitItem item : dataSplit.getItems()) {
String path = item.getPath();
String name = path.substring(path.lastIndexOf(File.separator) + 1);
String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;
HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);
treeNode.setChild(item.getSplitPoint(), build(hdfsPath,
new Data(attributes, item.getInstances())));
}
return treeNode;
}
/** Classify: use the decision tree to predict the class of each test instance and upload the results to HDFS. */
private void classify(TreeNode treeNode, String trainSet, String testSet, String output) {
OutputStream out = null;
BufferedWriter writer = null;
try {
Path trainSetPath = new Path(trainSet);
FileSystem trainFS = trainSetPath.getFileSystem(conf);
Path[] trainHdfsPaths = HDFSUtils.getPathFiles(trainFS, trainSetPath);
FSDataInputStream trainFSInputStream = trainFS.open(trainHdfsPaths[0]);
Data trainData = DataLoader.load(trainFSInputStream, true);
Path testSetPath = new Path(testSet);
FileSystem testFS = testSetPath.getFileSystem(conf);
Path[] testHdfsPaths = HDFSUtils.getPathFiles(testFS, testSetPath);
FSDataInputStream fsInputStream = testFS.open(testHdfsPaths[0]);
Data testData = DataLoader.load(fsInputStream, true);
DataHandler.fill(testData.getInstances(), trainData.getAttributes(), 0);
Object[] results = (Object[]) treeNode.classify(testData);
ShowUtils.print(results);
DataError dataError = new DataError(testData.getCategories(), results);
dataError.report();
String path = FileUtils.obtainRandomTxtPath();
out = new FileOutputStream(new File(path));
writer = new BufferedWriter(new OutputStreamWriter(out));
StringBuilder sb = null;
for (int i = 0, len = results.length; i < len; i++) {
sb = new StringBuilder();
sb.append(i+1).append("\t").append(results[i]);
writer.write(sb.toString());
writer.newLine();
}
writer.flush();
Path outputPath = new Path(output);
FileSystem fs = outputPath.getFileSystem(conf);
if (!fs.exists(outputPath)) {
fs.mkdirs(outputPath);
}
String name = path.substring(path.lastIndexOf(File.separator) + 1);
HDFSUtils.copyFromLocalFile(conf, path, output +
File.separator + name);
} catch (IOException e) {
e.printStackTrace();
} finally {
IOUtils.closeQuietly(out);
IOUtils.closeQuietly(writer);
}
}
public void run(String[] args) {
try {
if (null == conf) conf = new Configuration();
String[] inputArgs = new GenericOptionsParser(
conf, args).getRemainingArgs();
if (inputArgs.length != 3) {
System.out.println("error, please input three path.");
System.out.println("1. trainset path.");
System.out.println("2. testset path.");
System.out.println("3. result output path.");
System.exit(2);
}
Path input = new Path(inputArgs[0]);
FileSystem fs = input.getFileSystem(conf);
Path[] hdfsPaths = HDFSUtils.getPathFiles(fs, input);
FSDataInputStream fsInputStream = fs.open(hdfsPaths[0]);
Data trainData = DataLoader.load(fsInputStream, true);
/** Fill in default values for missing attributes. */
DataHandler.fill(trainData, 0);
String hdfsInput = prepare(trainData);
TreeNode treeNode = (TreeNode) build(hdfsInput, trainData);
TreeNodeHelper.print(treeNode, 0, null);
classify(treeNode, inputArgs[0], inputArgs[1], inputArgs[2]);
} catch (Exception e) {
e.printStackTrace();
}
}
public static void main(String[] args) {
DecisionTreeC45Job job = new DecisionTreeC45Job();
long startTime = System.currentTimeMillis();
job.run(args);
long endTime = System.currentTimeMillis();
System.out.println("spend time: " + (endTime - startTime));
}
}
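For reference, the job takes three arguments: the training-set path, the test-set path, and the result output path. A hypothetical way to launch it from another class (the HDFS paths are placeholders, not from the project):
public class RunDecisionTreeC45 {
    public static void main(String[] args) {
        // Placeholder HDFS paths; adjust them to your cluster layout.
        DecisionTreeC45Job.main(new String[] {
                "hdfs://namenode:9000/user/demo/trainset",
                "hdfs://namenode:9000/user/demo/testset",
                "hdfs://namenode:9000/user/demo/result"
        });
    }
}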
The implementation of CalculateC45GainRatioMR:
public class CalculateC45GainRatioMR {
private static void configureJob(Job job) {
job.setJarByClass(CalculateC45GainRatioMR.class);
job.setMapperClass(CalculateC45GainRatioMapper.class);
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(AttributeWritable.class);
job.setReducerClass(CalculateC45GainRatioReducer.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(AttributeGainWritable.class);
job.setInputFormatClass(TextInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
}
public static void main(String[] args) {
Configuration configuration = new Configuration();
try {
String[] inputArgs = new GenericOptionsParser(
configuration, args).getRemainingArgs();
if (inputArgs.length != 2) {
System.out.println("error, please input two path. input and output");
System.exit(2);
}
Job job = new Job(configuration, "Decision Tree");
FileInputFormat.setInputPaths(job, new Path(inputArgs[0]));
FileOutputFormat.setOutputPath(job, new Path(inputArgs[1]));
configureJob(job);
System.out.println(job.waitForCompletion(true) ? 0 : 1);
} catch (Exception e) {
e.printStackTrace();
}
}
}
class CalculateC45GainRatioMapper extends Mapper<LongWritable, Text,
Text, AttributeWritable> {
@Override
protected void setup(Context context) throws IOException,
InterruptedException {
super.setup(context);
}
@Override
protected void map(LongWritable key, Text value, Context context)
throws IOException, InterruptedException {
String line = value.toString();
// Each input line is expected to be: "<id> <category> <attr1:value1> <attr2:value2> ...".
StringTokenizer tokenizer = new StringTokenizer(line);
Long id = Long.parseLong(tokenizer.nextToken());
String category = tokenizer.nextToken();
boolean isCategory = true;
while (tokenizer.hasMoreTokens()) {
isCategory = false;
String attribute = tokenizer.nextToken();
String[] entry = attribute.split(":");
context.write(new Text(entry[0]), new AttributeWritable(id, category, entry[1]));
}
if (isCategory) {
// No attributes left on this line: emit the class label itself so the
// reducer can recognize a pure (leaf) partition.
context.write(new Text(category), new AttributeWritable(id, category, category));
}
}
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
super.cleanup(context);
}
}
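To make the input format concrete, here is a small standalone illustration (not part of the job) of how the mapper tokenizes one hypothetical record and which key/value pairs it would emit; the attribute names and values are made up:
public class MapperInputFormatDemo {
    public static void main(String[] args) {
        // A hypothetical input record: "<id> <category> <attr:value> ...".
        String line = "1 yes outlook:sunny windy:false";
        java.util.StringTokenizer tokenizer = new java.util.StringTokenizer(line);
        String id = tokenizer.nextToken();       // instance id
        String category = tokenizer.nextToken(); // class label
        while (tokenizer.hasMoreTokens()) {
            String[] entry = tokenizer.nextToken().split(":");
            // The real mapper writes (new Text(entry[0]), new AttributeWritable(id, category, entry[1])).
            System.out.println(entry[0] + " -> (" + id + ", " + category + ", " + entry[1] + ")");
        }
    }
}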
class CalculateC45GainRatioReducer extends Reducer<Text, AttributeWritable, Text, AttributeGainWritable> {
@Override
protected void setup(Context context) throws IOException, InterruptedException {
super.setup(context);
}
@Override
protected void reduce(Text key, Iterable<AttributeWritable> values,
Context context) throws IOException, InterruptedException {
String attributeName = key.toString();
double totalNum = 0.0;
// Class counts per attribute value, plus overall class counts (needed for the data-set entropy).
Map<String, Map<String, Integer>> attrValueSplits =
new HashMap<String, Map<String, Integer>>();
Map<String, Integer> totalCategoryCounts = new HashMap<String, Integer>();
int categorySum = 0;
Iterator<AttributeWritable> iterator = values.iterator();
boolean isCategory = false;
while (iterator.hasNext()) {
AttributeWritable attribute = iterator.next();
String attributeValue = attribute.getAttributeValue();
if (attributeName.equals(attributeValue)) {
// The mapper emits the class label as both key and value for instances with
// no remaining attributes, so this key is a pure (leaf) partition: just count
// the instances here, because reducer values can only be iterated once.
isCategory = true;
categorySum++;
continue;
}
Map<String, Integer> attrValueSplit = attrValueSplits.get(attributeValue);
if (null == attrValueSplit) {
attrValueSplit = new HashMap<String, Integer>();
attrValueSplits.put(attributeValue, attrValueSplit);
}
String category = attribute.getCategory();
Integer categoryNum = attrValueSplit.get(category);
attrValueSplit.put(category, null == categoryNum ? 1 : categoryNum + 1);
Integer overallNum = totalCategoryCounts.get(category);
totalCategoryCounts.put(category, null == overallNum ? 1 : overallNum + 1);
totalNum++;
}
if (isCategory) {
System.out.println("is Category");
System.out.println("sum: " + categorySum);
context.write(key, new AttributeGainWritable(attributeName,
categorySum, true, null));
} else {
// C4.5: gain ratio = (entropy before split - weighted entropy after split) / split info.
double infoAfterSplit = 0.0;
double splitInfo = 0.0;
for (Map<String, Integer> attrValueSplit : attrValueSplits.values()) {
double totalCategoryNum = 0;
for (Integer categoryNum : attrValueSplit.values()) {
totalCategoryNum += categoryNum;
}
double entropy = 0.0;
for (Integer categoryNum : attrValueSplit.values()) {
double p = categoryNum / totalCategoryNum;
entropy -= p * (Math.log(p) / Math.log(2));
}
double dj = totalCategoryNum / totalNum;
infoAfterSplit += dj * entropy;
splitInfo -= dj * (Math.log(dj) / Math.log(2));
}
double infoBeforeSplit = 0.0;
for (Integer count : totalCategoryCounts.values()) {
double p = count / totalNum;
infoBeforeSplit -= p * (Math.log(p) / Math.log(2));
}
double gainInfo = infoBeforeSplit - infoAfterSplit;
double gainRatio = splitInfo == 0.0 ? 0.0 : gainInfo / splitInfo;
StringBuilder splitPoints = new StringBuilder();
for (String attrValue : attrValueSplits.keySet()) {
splitPoints.append(attrValue).append(",");
}
splitPoints.deleteCharAt(splitPoints.length() - 1);
System.out.println("attribute: " + attributeName);
System.out.println("gainRatio: " + gainRatio);
System.out.println("splitPoints: " + splitPoints.toString());
context.write(key, new AttributeGainWritable(attributeName,
gainRatio, false, splitPoints.toString()));
}
}
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
super.cleanup(context);
}
}