第1关:MLlib介绍
package com.educoder.bigData.sparksql5;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.*;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class Test1 {
    /**
     * Builds and runs a minimal Spark ML pipeline (Tokenizer -> HashingTF ->
     * LogisticRegression) over a tiny in-memory text corpus, then prints the
     * predicted label for each test row.
     *
     * @param args unused command-line arguments
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("test1").master("local").getOrCreate();
        // Training rows: (label, raw text); label 1.0 marks the "positive" class.
        List<Row> trainingList = Arrays.asList(
                RowFactory.create(1.0, "a b c d E spark"),
                RowFactory.create(0.0, "b d"),
                RowFactory.create(1.0, "hadoop Mapreduce"),
                RowFactory.create(0.0, "f g h"));
        // Test rows: the label column is a placeholder (0.0) — only the text is scored.
        List<Row> testList = Arrays.asList(
                RowFactory.create(0.0, "spark I j k"),
                RowFactory.create(0.0, "l M n"),
                RowFactory.create(0.0, "f g"),
                RowFactory.create(0.0, "apache hadoop"));
        /********* Begin *********/
        // Schema shared by the training and test DataFrames: (label, text).
        StructType schema = new StructType(new StructField[] {
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("text", DataTypes.StringType, false, Metadata.empty())
        });
        Dataset<Row> training = spark.createDataFrame(trainingList, schema);
        Dataset<Row> test = spark.createDataFrame(testList, schema);
        // Stage 1: split each text into lowercase word tokens.
        Tokenizer tokenizer = new Tokenizer()
                .setInputCol("text")
                .setOutputCol("words");
        // Stage 2: hash tokens into a fixed-size (1000) term-frequency vector.
        HashingTF hashingTF = new HashingTF()
                .setNumFeatures(1000)
                .setInputCol("words")
                .setOutputCol("features");
        // Stage 3: logistic regression with L2 regularization, 10 iterations max.
        LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.001);
        // Chain the stages, fit on training data, and score the test data.
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[] { tokenizer, hashingTF, lr });
        PipelineModel model = pipeline.fit(training);
        Dataset<Row> predictions = model.transform(test);
        predictions.select("prediction").show();
        /********* End *********/
        // Fix: release the local Spark context and its resources — the original
        // never stopped the SparkSession it created.
        spark.stop();
    }
}
第2关:MLlib-垃圾邮件检测
package com.educoder.bigData.sparksql5;
import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.*;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.Word2Vec;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class Case2 {
    /**
     * Trains a spam/ham classifier on the "SMSSpamCollection" file using a
     * StringIndexer -> Word2Vec -> RandomForest pipeline.
     *
     * @param spark active SparkSession used to read the data and build DataFrames
     * @return the fitted PipelineModel
     */
    public static PipelineModel training(SparkSession spark) {
        /********* Begin *********/
        // Each input line is "<label> <word> <word> ...": the first token is the
        // ham/spam label, the remaining tokens form the message body.
        // NOTE(review): this splits on a single space; the canonical
        // SMSSpamCollection dataset is tab-separated — confirm the input format.
        // Fix: the original named its lambda parameter "String" (shadowing the
        // type name) and used two chained map() calls; collapsed into one.
        JavaRDD<Row> rows = spark.read().textFile("SMSSpamCollection").toJavaRDD()
                .map(line -> {
                    String[] tokens = line.split(" ");
                    return RowFactory.create(tokens[0], Arrays.copyOfRange(tokens, 1, tokens.length));
                });
        // Schema: (label: string, message: array<string>).
        StructType schema = new StructType(new StructField[] {
                new StructField("label", DataTypes.StringType, false, Metadata.empty()),
                new StructField("message", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
        });
        Dataset<Row> df = spark.createDataFrame(rows, schema);
        // Stage 1: map the string labels to numeric indices for the classifier.
        StringIndexer labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("indexedLabel");
        // Stage 2: embed each message as a 200-dimensional Word2Vec vector.
        Word2Vec word2Vec = new Word2Vec()
                .setInputCol("message")
                .setOutputCol("features")
                .setVectorSize(200)
                .setMinCount(1)
                .setWindowSize(5);
        // Stage 3: random forest over the embedded features.
        RandomForestClassifier rf = new RandomForestClassifier()
                .setLabelCol("indexedLabel")
                .setFeaturesCol("features")
                .setNumTrees(50)
                .setMaxDepth(10)
                .setMaxBins(32);
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[] { labelIndexer, word2Vec, rf });
        PipelineModel model = pipeline.fit(df);
        /********* End *********/
        return model;
    }
}
第3关:MLlib-红酒分类预测
package com.educoder.bigData.sparksql5;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.*;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class Case3 {
    /**
     * Fits a random-forest classifier to the wine dataset read from
     * "dataset.csv" (CSV layout: label followed by numeric feature columns).
     *
     * @param spark active SparkSession used to read the data and build the DataFrame
     * @return the fitted PipelineModel
     */
    public static PipelineModel training(SparkSession spark) {
        /********* Begin *********/
        // Parse each CSV line into a (label, dense feature vector) row.
        JavaRDD<Row> parsed = spark.sparkContext().textFile("dataset.csv", 1).toJavaRDD()
                .map(line -> {
                    String[] cols = line.split(",");
                    double[] values = new double[cols.length - 1];
                    for (int i = 0; i < values.length; i++) {
                        values[i] = Double.parseDouble(cols[i + 1]);
                    }
                    return RowFactory.create(Double.parseDouble(cols[0]), Vectors.dense(values));
                });
        // Schema: (label: double, features: vector).
        StructType wineSchema = new StructType(new StructField[] {
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())
        });
        Dataset<Row> wineData = spark.createDataFrame(parsed, wineSchema);
        // Random forest: 50 trees, depth <= 10, 32 bins per feature split.
        RandomForestClassifier forest = new RandomForestClassifier()
                .setLabelCol("label")
                .setFeaturesCol("features")
                .setNumTrees(50)
                .setMaxDepth(10)
                .setMaxBins(32);
        // Single-stage pipeline, fitted immediately on the parsed data.
        PipelineModel model = new Pipeline()
                .setStages(new PipelineStage[] { forest })
                .fit(wineData);
        /********* End *********/
        return model;
    }
}