Spark基于LogisticRegression逻辑回归实现英文垃圾邮件分类(Java版/Scala版)
中英文邮件分词有所不同,若需要中文垃圾邮件分类请移步我上一篇博客Spark基于NaiveBayes朴素贝叶斯算法实现中文垃圾邮件分类实战(Java / Scala)
此次测试的数据集大概格式如下:
spam You'll not rcv any more msgs from the chat svc. For FREE Hardcore services text GO to: 69988 If u get nothing u must Age Verify with yr network & try again
ham Got c... I lazy to type... I forgot ü in lect... I saw a pouch but like not v nice...
ham K, text me when you're on the way
ham Sir, Waiting for your mail.
ham A swt thought: "Nver get tired of doing little things 4 lovable persons.." Coz..somtimes those little things occupy d biggest part in their Hearts.. Gud ni8
ham I know you are. Can you pls open the back?
ham Yes see ya not on the dot
ham Whats the staff name who is taking class for us?
spam FreeMsg Why haven't you replied to my text? I'm Randy, sexy, female and live local. Luv to hear from u. Netcollex Ltd 08700621170150p per msg reply Stop to end
ham Ummma.will call after check in.our life will begin from qatar so pls pray very hard.
ham K..i deleted my contact that why?
ham Sindu got job in birla soft ..
ham The wine is flowing and i'm i have nevering..
ham Yup i thk cine is better cos no need 2 go down 2 plaz
直接上代码,分为两个语言版本,Java版本与Scala版本。
一、Java版
- EnglishSpamJava.java
package top.it1002.spark.ml.EnglistSpam;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import java.io.File;
import java.util.ArrayList;
/**
 * @Author 王磊
 * @Date 2018/12/17
 * @ClassName EnglishSpamJava
 * @Description English spam-mail classification with Spark LogisticRegression (Java version).
 *              Reads tab-separated "label\tmessage" lines, trains a
 *              Tokenizer -> HashingTF -> LogisticRegression pipeline,
 *              persists the model once, then scores a few sample messages.
 **/
public class EnglishSpamJava {
    public static void main(String[] args) throws Exception {
        SparkConf conf = new SparkConf().setMaster("local[5]").setAppName("EnglishSpam");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SparkSession session = SparkSession
                .builder()
                .config(conf)
                .getOrCreate();
        try {
            // Each line: "<label>\t<content>" where label is "ham" (0.0) or "spam" (1.0).
            JavaRDD<String> lines = jsc.textFile("C:\\Users\\asus\\Desktop\\data\\spam_ham.txt");
            RDD<Row> rowRDD = lines.map(new Function<String, Row>() {
                public Row call(String v1) throws Exception {
                    // Limit 2: split only on the first tab so message text
                    // that itself contains tabs is kept intact.
                    String[] arr = v1.split("\t", 2);
                    double type = arr[0].equals("ham") ? 0.0 : 1.0;
                    return RowFactory.create(type, arr[1]);
                }
            }).rdd();

            // Schema: label (double), content (raw message text).
            ArrayList<StructField> fields = new ArrayList<StructField>();
            fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
            fields.add(DataTypes.createStructField("content", DataTypes.StringType, true));
            StructType schema = DataTypes.createStructType(fields);
            Dataset<Row> data = session.createDataFrame(rowRDD, schema);

            // Tokenizer: whitespace-split the message into words.
            Tokenizer tkzer = new Tokenizer().setInputCol("content").setOutputCol("word");
            // HashingTF: hash words into 1000 buckets, producing term-frequency feature vectors.
            HashingTF HTF = new HashingTF().setNumFeatures(1000).setInputCol("word").setOutputCol("features");
            // Logistic regression classifier.
            LogisticRegression lRegress = new LogisticRegression().setMaxIter(20).setRegParam(0.1);

            // Pipeline chains the three stages: tokenize -> hash -> classify.
            PipelineStage[] pp = new PipelineStage[3];
            pp[0] = tkzer;
            pp[1] = HTF;
            pp[2] = lRegress;
            Pipeline pip = new Pipeline().setStages(pp);

            // Fit the training data to produce a model.
            PipelineModel model = pip.fit(data);

            // Persist the model only if the target path does not exist yet.
            // (Checking the exact target avoids the NullPointerException the old
            // parent-dir list() check hit when the directory was missing, and
            // PipelineModel.save throws if the path already exists.)
            File modelDir = new File("C:\\Users\\asus\\Desktop\\data\\email\\english\\model\\java");
            if (!modelDir.exists()) {
                model.save(modelDir.getPath());
            }

            // Build a small unlabeled test set and score it with the fitted model.
            ArrayList<Row> testRowList = new ArrayList<Row>();
            testRowList.add(RowFactory.create("I know you are. Can you pls open the back?"));
            testRowList.add(RowFactory.create("FreeMsg Why haven't you replied to my text? I'm Randy, sexy, female and live local. Luv to hear from u. Netcollex Ltd 08700621170150p per msg reply Stop to end"));
            testRowList.add(RowFactory.create("Hello, my love. What are you doing? Did you get to that interview today? Are you you happy? Are you being a good boy? Do you think of me?Are you missing me ?"));
            ArrayList<StructField> fields1 = new ArrayList<StructField>();
            fields1.add(DataTypes.createStructField("content", DataTypes.StringType, true));
            StructType schema1 = DataTypes.createStructType(fields1);
            Dataset<Row> testSet = session.createDataFrame(testRowList, schema1);

            Dataset<Row> predict = model.transform(testSet);
            predict.show();
            predict.createTempView("res");
            session.sql("select content,prediction from res").show();
        } finally {
            // Release Spark resources even if training/prediction fails.
            session.stop();
            jsc.close();
        }
    }
}
二、Scala版
- EnglishSpamScala.scala
package top.it1002.spark.ml.EnglistSpam
import java.io.File
import java.util
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.{SparkConf, SparkContext}
/**
 * @Author 王磊
 * @Date 2018/12/17
 * @ClassName EnglishSpamScala
 * @Description English spam-mail classification with Spark LogisticRegression (Scala version).
 *              Reads tab-separated "label\tmessage" lines, trains a
 *              Tokenizer -> HashingTF -> LogisticRegression pipeline,
 *              persists the model once, then scores a few sample messages.
 **/
object EnglishSpamScala {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[5]").setAppName("EnglishSpam")
    val sess = SparkSession.builder().config(conf).getOrCreate()
    // Reuse the session's context rather than constructing a second SparkContext.
    val context = sess.sparkContext
    try {
      // Read the training set: each line is "<label>\t<content>".
      val lines = context.textFile("C:\\Users\\asus\\Desktop\\data\\spam_ham.txt")
      // Map "ham" -> 0.0 and anything else ("spam") -> 1.0; limit 2 keeps
      // message text intact even if it contains further tab characters.
      val rowRDD = lines.map { line =>
        val arr = line.split("\t", 2)
        val types = if (arr(0) == "ham") 0.0 else 1.0
        Row(types, arr(1))
      }
      // Schema: label (double), content (raw message text).
      val fields = Array[StructField](
        new StructField("label", DataTypes.DoubleType, true),
        new StructField("content", DataTypes.StringType, true)
      )
      val schema = new StructType(fields)
      val data = sess.createDataFrame(rowRDD, schema)

      // Tokenizer: whitespace-split the message into words.
      val words = new Tokenizer().setInputCol("content").setOutputCol("words")
      // HashingTF: hash words into 1000 buckets, producing term-frequency feature vectors.
      val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features")
      // Logistic regression classifier.
      val logisticRegression = new LogisticRegression().setMaxIter(20).setRegParam(0.1)
      // Pipeline chains the three stages: tokenize -> hash -> classify.
      val pip = new Pipeline().setStages(Array(words, hashingTF, logisticRegression))

      // Fit the training data to produce a model.
      val model = pip.fit(data)

      // Persist the model only if the target path does not exist yet.
      // (Checking the exact target avoids the NullPointerException the old
      // file.list() check hit when the parent directory was missing, and
      // PipelineModel.save throws if the path already exists.)
      val modelPath = "C:\\Users\\asus\\Desktop\\data\\email\\english\\model"
      val modelFile = new File(modelPath + "\\scala")
      if (!modelFile.exists()) model.save(modelFile.getPath)

      // Build a small unlabeled test set and score it with the fitted model.
      val testList = new util.ArrayList[String]()
      testList.add("Get me out of this dump heap. My mom decided to come to lowes. BORING.")
      testList.add("REMINDER FROM O2: To get 2.50 pounds free call credit and details of great offers pls reply 2 this text with your valid name, house no and postcode")
      testList.add("Cool, what time you think you can get here?")
      import sess.implicits._
      val testSet = sess.createDataset(testList).withColumnRenamed("value", "content")

      val predict = model.transform(testSet)
      predict.createTempView("res")
      sess.sql("select content,prediction from res").show()
    } finally {
      // Release Spark resources even if training/prediction fails.
      sess.stop()
    }
  }
}
后续会持续更新Spark mllib相关小例子