(1)数据
1,2011-01-01,1,0,1,0,0,6,0,1,0.24,0.2879,0.81,0,3,13,16
2,2011-01-01,1,0,1,1,0,6,0,1,0.22,0.2727,0.8,0,8,32,40
3,2011-01-01,1,0,1,2,0,6,0,1,0.22,0.2727,0.8,0,5,27,32
各列含义（共 17 列，与上面每行数据逐列对应）
instant,dteday,season,yr,mnth,hr,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt
(2)代码
/**
 * Trains an MLlib decision-tree classifier on the hourly bike-sharing data in {@code hour.txt}
 * and prints the test error and the learned tree.
 *
 * <p>Each input row has 17 comma-separated columns (0-based):
 * 0 instant, 1 dteday, 2 season, 3 yr, 4 mnth, 5 hr, 6 holiday, 7 weekday, 8 workingday,
 * 9 weathersit, 10 temp, 11 atemp, 12 hum, 13 windspeed, 14 casual, 15 registered, 16 cnt.
 * Columns 2..13 form the feature vector; column 16 ({@code cnt}) is the label.
 */
public class HWDecisionTreeClass {

    /** Fixed seed so the train/test split (and hence the reported error) is reproducible. */
    private static final long SPLIT_SEED = 42L;

    /** Parses one CSV row of {@code hour.txt} into a {@link LabeledPoint}. */
    private static class ParsePoint implements Function<String, LabeledPoint> {
        /** Field separator of the input file. */
        private static final Pattern COMMA = Pattern.compile(",");
        /** Number of columns every data row must have. */
        private static final int EXPECTED_COLUMNS = 17;
        /** 0-based index of the first feature column ("season"). */
        private static final int FIRST_FEATURE = 2;
        /**
         * 0-based index one past the last feature column ("windspeed" is index 13).
         * "casual" (14) and "registered" (15) are deliberately excluded: their sum equals the
         * label "cnt" (e.g. 3 + 13 = 16 in the first sample row), so using them as features
         * would leak the answer into the model.
         */
        private static final int END_FEATURE = 14;
        /** 0-based index of the label column ("cnt"). */
        private static final int LABEL = 16;

        /**
         * Converts a CSV row into a labeled point.
         *
         * @param line one data row of {@code hour.txt}
         * @return a point whose label is {@code cnt} and whose features are columns 2..13
         * @throws IllegalArgumentException if the row does not have exactly 17 columns
         * @throws NumberFormatException if a feature or label column is not numeric
         */
        @Override
        public LabeledPoint call(String line) {
            String[] parts = COMMA.split(line.trim());
            if (parts.length != EXPECTED_COLUMNS) {
                throw new IllegalArgumentException(
                        "Expected " + EXPECTED_COLUMNS + " columns but got " + parts.length + ": " + line);
            }
            double[] features = new double[END_FEATURE - FIRST_FEATURE];
            for (int col = FIRST_FEATURE; col < END_FEATURE; col++) {
                features[col - FIRST_FEATURE] = Double.parseDouble(parts[col]);
            }
            return new LabeledPoint(Double.parseDouble(parts[LABEL]), Vectors.dense(features));
        }
    }

    /**
     * Loads {@code hour.txt}, splits it 90% / 10% into training and test sets, trains a
     * decision-tree classifier and prints sample predictions, the test error and the tree.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        try {
            // Load and parse the data, skipping blank lines and a CSV header row if one is present.
            String datapath = "hour.txt";
            JavaRDD<String> lines = jsc.textFile(datapath).filter(new Function<String, Boolean>() {
                @Override
                public Boolean call(String line) {
                    String trimmed = line.trim();
                    return !trimmed.isEmpty() && !trimmed.startsWith("instant");
                }
            });
            // Cached because the parsed data is scanned several times (preview, split, training, evaluation).
            JavaRDD<LabeledPoint> traindata = lines.map(new ParsePoint()).cache();
            List<LabeledPoint> take = traindata.take(3);
            for (LabeledPoint labeledPoint : take) {
                System.out.println("----->" + labeledPoint.features());
                System.out.println("----->" + labeledPoint.label());
            }

            // 90% of the data for training, 10% for testing.
            JavaRDD<LabeledPoint>[] splits = traindata.randomSplit(new double[] { 0.9, 0.1 }, SPLIT_SEED);
            JavaRDD<LabeledPoint> trainingData = splits[0];
            JavaRDD<LabeledPoint> testData = splits[1];

            // Model parameters. An empty categoricalFeaturesInfo means every feature is treated as continuous.
            // MLlib requires class labels to be integers in [0, numClasses); 1900 is assumed to exceed
            // the largest cnt in hour.txt -- TODO confirm against the full data set.
            // NOTE(review): cnt is a count; DecisionTree.trainRegressor with "variance" impurity may fit better.
            int numClasses = 1900;
            Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
            String impurity = "gini";
            int maxDepth = 20;
            int maxBins = 32;

            // Train the decision-tree classifier.
            final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
                    categoricalFeaturesInfo, impurity, maxDepth, maxBins);

            // Predict on the test set and pair each prediction with the true label.
            JavaPairRDD<Double, Double> predictionAndLabel =
                    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                        @Override
                        public Tuple2<Double, Double> call(LabeledPoint p) {
                            return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
                        }
                    });
            System.out.println(predictionAndLabel.take(10));

            // Test error = fraction of test points whose predicted class differs from the label.
            long testCount = testData.count();
            if (testCount == 0) {
                System.out.println("Test set is empty; cannot compute test error.");
            } else {
                long misclassified = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<Double, Double> pl) {
                        return !pl._1().equals(pl._2());
                    }
                }).count();
                double testErr = (double) misclassified / testCount;
                System.out.println("Test Error: -------------------------------------------------------------------" + testErr);
            }
            System.out.println("Learned classification tree model:\n-------------------------------------------"
                    + model.toDebugString());
        } finally {
            // Release the local Spark context even if loading, training or evaluation fails.
            jsc.stop();
        }
    }
}