头歌 (EduCoder): Machine Learning with Spark MLlib

Level 1: Introduction to MLlib

This solution builds a Tokenizer → HashingTF → LogisticRegression pipeline over a handful of labeled sentences, then prints the predictions for a small test set.

package com.educoder.bigData.sparksql5;

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class Test1 {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("test1").master("local").getOrCreate();

        // Training documents: (label, text) pairs.
        List<Row> trainingList = Arrays.asList(
                RowFactory.create(1.0, "a b c d E spark"),
                RowFactory.create(0.0, "b d"),
                RowFactory.create(1.0, "hadoop Mapreduce"),
                RowFactory.create(0.0, "f g h"));

        // Test documents: the labels are placeholders, only the text is scored.
        List<Row> testList = Arrays.asList(
                RowFactory.create(0.0, "spark I j k"),
                RowFactory.create(0.0, "l M n"),
                RowFactory.create(0.0, "f g"),
                RowFactory.create(0.0, "apache hadoop"));

        /********* Begin *********/
        StructType schema = new StructType(new StructField[] {
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("text", DataTypes.StringType, false, Metadata.empty()) });
        Dataset<Row> training = spark.createDataFrame(trainingList, schema);
        Dataset<Row> test = spark.createDataFrame(testList, schema);

        // Pipeline: tokenize the text, hash the tokens into a fixed-size
        // feature vector, then fit a logistic regression on those features.
        Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");
        HashingTF hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features");
        LogisticRegression lr = new LogisticRegression().setMaxIter(10).setRegParam(0.001);
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { tokenizer, hashingTF, lr });

        // Fit on the training data, then score the test set.
        PipelineModel model = pipeline.fit(training);
        model.transform(test).select("prediction").show();
        /********* End *********/
    }
}
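
The exercise only selects the prediction column, but the transformed test set also carries the rawPrediction and probability columns that LogisticRegression appends by default. A minimal sketch for inspecting them; the Test1Inspect class name is ours, not part of the exercise, and it assumes the model and test Dataset produced in Test1 above:

package com.educoder.bigData.sparksql5;

import org.apache.spark.ml.PipelineModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper, not part of the EduCoder exercise.
public class Test1Inspect {

    // Show each test document next to its class-probability vector and
    // final prediction. "model" and "test" come from Test1 above.
    public static void inspect(PipelineModel model, Dataset<Row> test) {
        Dataset<Row> predictions = model.transform(test);
        predictions.select("text", "probability", "prediction").show(false);
    }
}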

Level 2: MLlib - Spam Detection

This solution parses each line of SMSSpamCollection into a string label and a token array, then fits a StringIndexer → Word2Vec → GBTClassifier pipeline.

package com.educoder.bigData.sparksql5;

import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.Word2Vec;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class Case2 {

    public static PipelineModel training(SparkSession spark) {
        /********* Begin *********/
        // Split each line of SMSSpamCollection into its tokens, then separate
        // the leading label from the remaining message words.
        JavaRDD<Row> map = spark.read().textFile("SMSSpamCollection").toJavaRDD()
                .map(line -> line.split(" "))
                .map(new Function<String[], Row>() {
                    @Override
                    public Row call(String[] v1) throws Exception {
                        // v1[0] is the label; the rest is the message.
                        String[] message = Arrays.copyOfRange(v1, 1, v1.length);
                        return RowFactory.create(v1[0], message);
                    }
                });

        StructType schema = new StructType(new StructField[] {
                new StructField("label", DataTypes.StringType, false, Metadata.empty()),
                new StructField("message", DataTypes.createArrayType(DataTypes.StringType), false,
                        Metadata.empty()) });
        Dataset<Row> data = spark.createDataFrame(map, schema);

        // Pipeline: index the string label, embed the token array with
        // Word2Vec, then train a gradient-boosted tree classifier.
        StringIndexer labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel");
        Word2Vec word2Vec = new Word2Vec().setInputCol("message").setOutputCol("features");
        GBTClassifier gbt = new GBTClassifier().setLabelCol("indexedLabel").setFeaturesCol("features");
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { labelIndexer, word2Vec, gbt });
        PipelineModel fit = pipeline.fit(data);
        /********* End *********/
        return fit;
    }
}
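
training() fits on the whole file and returns the model without measuring it. If you hold out part of the data, the fitted pipeline can be scored with MulticlassClassificationEvaluator; because the fitted StringIndexer is a stage of the returned model, transform() also adds the indexedLabel column the evaluator needs. A minimal sketch; the Case2Eval class name is ours, and "test" is assumed to have the same (label, message) schema as the training data:

package com.educoder.bigData.sparksql5;

import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper, not part of the EduCoder exercise.
public class Case2Eval {

    // Accuracy of the fitted spam pipeline on a held-out set. The fitted
    // StringIndexer inside the pipeline adds "indexedLabel" during transform,
    // so the evaluator can compare it against "prediction".
    public static double accuracy(PipelineModel model, Dataset<Row> test) {
        Dataset<Row> predictions = model.transform(test);
        return new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName("accuracy")
                .evaluate(predictions);
    }
}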

Level 3: MLlib - Wine Classification

This solution parses dataset.csv into (label, dense feature vector) rows and fits a single-stage RandomForestClassifier pipeline.

package com.educoder.bigData.sparksql5;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class Case3 {

    public static PipelineModel training(SparkSession spark) {
        /********* Begin *********/
        JavaRDD<Row> javaRDD = spark.read().csv("dataset.csv").toJavaRDD();
        JavaRDD<Row> map = javaRDD.map(new Function<Row, Row>() {
            @Override
            public Row call(Row v1) throws Exception {
                int size = v1.size();
                // The first column is the label.
                double label = Double.parseDouble(v1.get(0).toString());
                // The remaining columns are the features.
                double[] features = new double[size - 1];
                for (int n = 1; n < size; n++) {
                    features[n - 1] = Double.parseDouble(v1.get(n).toString());
                }
                // Pack them into a (label, dense vector) row.
                return RowFactory.create(label, Vectors.dense(features));
            }
        });

        StructType schema = new StructType(new StructField[] {
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty()) });
        Dataset<Row> data = spark.createDataFrame(map, schema);

        // A single-stage pipeline: a random forest on the dense feature vectors.
        RandomForestClassifier rf = new RandomForestClassifier()
                .setLabelCol("label")
                .setFeaturesCol("features");
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { rf });
        PipelineModel fit = pipeline.fit(data);
        /********* End *********/
        return fit;
    }
}
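
The exercise only returns the fitted model; to sanity-check the forest yourself, you can hold out part of the data before fitting. A minimal sketch; the Case3Eval class name, the 80/20 split, and the seed are ours, and "data" is assumed to be a Dataset with the (label, features) schema that Case3 builds from dataset.csv:

package com.educoder.bigData.sparksql5;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper, not part of the EduCoder exercise.
public class Case3Eval {

    // Fit the same single-stage random-forest pipeline on 80% of the data
    // and report accuracy on the remaining 20%.
    public static double holdOutAccuracy(Dataset<Row> data) {
        Dataset<Row>[] splits = data.randomSplit(new double[] { 0.8, 0.2 }, 42L);
        RandomForestClassifier rf = new RandomForestClassifier()
                .setLabelCol("label")
                .setFeaturesCol("features");
        PipelineModel model = new Pipeline()
                .setStages(new PipelineStage[] { rf })
                .fit(splits[0]);
        return new MulticlassClassificationEvaluator()
                .setLabelCol("label")
                .setPredictionCol("prediction")
                .setMetricName("accuracy")
                .evaluate(model.transform(splits[1]));
    }
}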