Spark 实现简单的垃圾邮件分类 —— Java 源码
代码部分
package cn.cc.spark;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.feature.HashingTF;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
public final class Spam {

    // Compiled once and shared by all tasks. "\\w+" requires at least one word
    // character; the original "[\\w]*" also matched the empty string at every
    // non-word boundary, polluting the term-frequency vectors with empty tokens.
    private static final Pattern WORD_PATTERN = Pattern.compile("\\w+");

    /**
     * Trains a logistic-regression (SGD) spam classifier from one known-spam and
     * one known-ham message on HDFS, then predicts the class of two held-out
     * messages and logs the results.
     */
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("垃圾邮件训练").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        try {
            // Training corpora: one spam sample (label 1.0) and one ham sample (label 0.0).
            JavaRDD<String> spam = sc.textFile("hdfs://localhost:8020/sample/spam/spmsga1.eml");
            JavaRDD<String> mail = sc.textFile("hdfs://localhost:8020/sample/mail/3-1msg1.txt");
            // Hash each token into a 10000-dimensional term-frequency feature vector.
            final HashingTF tf = new HashingTF(10000);
            JavaRDD<LabeledPoint> positiveData = setLabeledPoint(spam, tf, 1.0);
            JavaRDD<LabeledPoint> negativeData = setLabeledPoint(mail, tf, 0.0);
            JavaRDD<LabeledPoint> trainingData = positiveData.union(negativeData);
            trainingData.cache(); // reused across the SGD iterations, so keep it in memory
            LogisticRegressionWithSGD lrLearner = new LogisticRegressionWithSGD();
            LogisticRegressionModel model = lrLearner.run(trainingData.rdd());
            // Score one unseen spam message and one unseen ham message.
            Vector positiveSample = tf.transform(filterText(sc, "hdfs://localhost:8020/sample/spam/spmsga2.eml"));
            predictionResult("spmsga2.eml", model.predict(positiveSample));
            Vector negativeSample = tf.transform(filterText(sc, "hdfs://localhost:8020/sample/mail/3-1msg2.txt"));
            predictionResult("3-1msg2.txt", model.predict(negativeSample));
        } finally {
            // close() also stops the context; the original's stop()+close() pair was
            // redundant, and without finally the context leaked on any exception.
            sc.close();
        }
    }

    /**
     * Logs the prediction for one email. With the default 0.5 threshold the model
     * returns exactly 0.0 or 1.0, so the {@code == 1.0} comparison is reliable.
     *
     * @param emailName    file name of the scored message (used only in the log line)
     * @param predictValue class predicted by the model: 1.0 = spam, 0.0 = ham
     */
    public static void predictionResult(String emailName, double predictValue) {
        if (predictValue == 1.0) {
            Logger.getGlobal().info(emailName + "邮件的预测结果为垃圾邮件" + predictValue);
        } else {
            Logger.getGlobal().info(emailName + "预测结果为正常邮件" + predictValue);
        }
    }

    /**
     * Tokenizes every line of the RDD into word characters, hashes the tokens into
     * a term-frequency vector, and attaches the given class label.
     *
     * @param rdd   lines of the training corpus
     * @param tf    shared hashing term-frequency transformer
     * @param label class label for every point produced (1.0 = spam, 0.0 = ham)
     * @return RDD of labeled feature vectors ready for training
     */
    public static JavaRDD<LabeledPoint> setLabeledPoint(JavaRDD<String> rdd, HashingTF tf, Double label) {
        return rdd.map(new Function<String, LabeledPoint>() {
            private static final long serialVersionUID = 1L;

            @Override
            public LabeledPoint call(String email) {
                // WORD_PATTERN is static, so it is compiled once per JVM rather than
                // once per record as in the original.
                List<String> tokens = new LinkedList<>();
                Matcher matcher = WORD_PATTERN.matcher(email);
                while (matcher.find()) {
                    tokens.add(matcher.group());
                }
                return new LabeledPoint(label, tf.transform(tokens));
            }
        });
    }

    /**
     * Reads the file at {@code path}, drops blank lines, and splits the remaining
     * lines on single spaces into a flat token list.
     *
     * @param sc   active Spark context
     * @param path HDFS path of the message to tokenize
     * @return tokens of all non-blank lines, in file order
     */
    public static List<String> filterText(JavaSparkContext sc, String path) {
        List<String> tokens = new LinkedList<>();
        // NOTE: collect() may return a fixed-size/immutable list; the original code
        // called Iterator.remove() on it, risking UnsupportedOperationException.
        // The removal was also dead work (the collected list is never used again),
        // so blank lines are simply skipped instead.
        for (String line : sc.textFile(path).collect()) {
            if (!line.trim().isEmpty()) {
                tokens.addAll(Arrays.asList(line.split(" ")));
            }
        }
        return tokens;
    }
}
相关依赖
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.10</artifactId>
<version>1.3.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.10</artifactId>
<version>1.3.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.10</artifactId>
<version>1.3.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_2.10</artifactId>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming-kafka_2.10</artifactId>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>com.datastax.spark</groupId>
<artifactId>spark-cassandra-connector_2.10</artifactId>
<version>1.0.0-rc5</version>
</dependency>
<dependency>
<groupId>com.datastax.spark</groupId>
<artifactId>spark-cassandra-connector-java_2.10</artifactId>
<version>1.0.0-rc5</version>
</dependency>
<dependency>
<groupId>org.elasticsearch</groupId>
<artifactId>elasticsearch-hadoop-mr</artifactId>
<version>2.0.0.RC1</version>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-client</artifactId>
<version>8.1.14.v20131031</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.3.3</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.0</version>
</dependency>
<dependency>
<groupId>net.sf.opencsv</groupId>
<artifactId>opencsv</artifactId>
<version>2.0</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
<version>2.2.1</version>
</dependency>
</dependencies>