English Spam Classification with Spark's LogisticRegression (Java / Scala)



Tokenization works differently for English and Chinese mail. If you need Chinese spam classification, see my previous post on Chinese spam classification with Spark's NaiveBayes (Java / Scala).

The test dataset looks roughly like this: each line is a label (spam or ham), a tab, then the message text.

spam	You'll not rcv any more msgs from the chat svc. For FREE Hardcore services text GO to: 69988 If u get nothing u must Age Verify with yr network & try again
ham	Got c... I lazy to type... I forgot ü in lect... I saw a pouch but like not v nice...
ham	K, text me when you're on the way
ham	Sir, Waiting for your mail.
ham	A swt thought: "Nver get tired of doing little things 4 lovable persons.." Coz..somtimes those little things occupy d biggest part in their Hearts.. Gud ni8
ham	I know you are. Can you pls open the back?
ham	Yes see ya not on the dot
ham	Whats the staff name who is taking class for us?
spam	FreeMsg Why haven't you replied to my text? I'm Randy, sexy, female and live local. Luv to hear from u. Netcollex Ltd 08700621170150p per msg reply Stop to end
ham	Ummma.will call after check in.our life will begin from qatar so pls pray very hard.
ham	K..i deleted my contact that why?
ham	Sindu got job in birla soft ..
ham	The wine is flowing and i'm i have nevering..
ham	Yup i thk cine is better cos no need 2 go down 2 plaz
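Both programs below map the textual tag to a numeric label before training, because LogisticRegression expects a numeric label column. Here is a minimal REPL-style Scala sketch of that parsing rule (the line value is just an illustrative sample, not read from the file):

// Split one dataset line on the first tab: tag ("spam"/"ham") and message body
val line = "spam\tFreeMsg Why haven't you replied to my text?"
val Array(tag, content) = line.split("\t", 2)
// Encode the tag as a double label: ham -> 0.0, spam -> 1.0
val label = if (tag == "ham") 0.0 else 1.0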

Straight to the code, in two versions: Java and Scala.

1. Java Version

  • EnglishSpamJava.java
package top.it1002.spark.ml.EnglistSpam;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.io.File;
import java.util.ArrayList;

/**
* @Author      王磊
* @Date        2018/12/17
* @ClassName   EnglishSpamJava
* @Description English spam classification, Java version
**/
public class EnglishSpamJava {
   public static void main(String[] args) throws Exception{
       SparkConf conf = new SparkConf().setMaster("local[5]").setAppName("EnglishSpam");
       JavaSparkContext jsc = new JavaSparkContext(conf);
       SparkSession session = SparkSession
               .builder()
               .config(conf)
               .getOrCreate();
       JavaRDD<String> lines = jsc.textFile("C:\\Users\\asus\\Desktop\\data\\spam_ham.txt");
       RDD<Row> rowRDD = lines.map(new Function<String, Row>() {
           public Row call(String v1) throws Exception {
               String[] arr = v1.split("\t");
               double type = arr[0].equals("ham") ? 0.0 : 1.0;
               return RowFactory.create(type,arr[1]);
           }
       }).rdd();
       ArrayList<StructField> fields = new ArrayList<StructField>();
       fields.add(DataTypes.createStructField("label", DataTypes.DoubleType,true));
       fields.add(DataTypes.createStructField("content", DataTypes.StringType,true));
       StructType schema = DataTypes.createStructType(fields);
       Dataset<Row> data = session.createDataFrame(rowRDD, schema);
        // Tokenizer: split the message text into words
       Tokenizer tkzer = new Tokenizer().setInputCol("content").setOutputCol("word");
        // HashingTF: hash words into 1000 buckets and count term frequencies
       HashingTF HTF = new HashingTF().setNumFeatures(1000).setInputCol("word").setOutputCol("features");
        // Logistic regression classifier
       LogisticRegression lRegress = new LogisticRegression().setMaxIter(20).setRegParam(0.1);
        // Pipeline: chain the three stages together
       PipelineStage[] pp = new PipelineStage[3];
       pp[0] =  tkzer;
       pp[1] =  HTF;
       pp[2] =  lRegress;
       Pipeline pip = new Pipeline().setStages(pp);
        // Fit the pipeline to the training data to produce a model
       PipelineModel model = pip.fit(data);
        // Persist the model; save() throws if the target path already exists,
        // so only save when it does not
        File modelDir = new File("C:\\Users\\asus\\Desktop\\data\\email\\english\\model\\java");
        if(!modelDir.exists()){
            model.save(modelDir.getPath());
        }

        // Build a small hand-made dataset for testing
       ArrayList<Row> testRowList = new ArrayList<Row>();
       testRowList.add(RowFactory.create("I know you are. Can you pls open the back?"));
       testRowList.add(RowFactory.create("FreeMsg Why haven't you replied to my text? I'm Randy, sexy, female and live local. Luv to hear from u. Netcollex Ltd 08700621170150p per msg reply Stop to end"));
       testRowList.add(RowFactory.create("Hello, my love. What are you doing? Did you get to that interview today? Are you you happy? Are you being a good boy? Do you think of me?Are you missing me ?"));
       ArrayList<StructField> fields1 = new ArrayList<StructField>();
       fields1.add(DataTypes.createStructField("content",DataTypes.StringType,true));
       StructType schema1 = DataTypes.createStructType(fields1);
       Dataset<Row> testSet = session.createDataFrame(testRowList, schema1);
       Dataset<Row> predict = model.transform(testSet);
       predict.show();
       predict.createTempView("res");
       session.sql("select content,prediction from res").show();
   }
}
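In the output of predict.show(), the prediction column follows the label encoding above: 1.0 means spam, 0.0 means ham, and the probability column carries the class probabilities behind each prediction. Note that setNumFeatures(1000) keeps the feature space deliberately small for this demo; with a larger vocabulary you may want more buckets (Spark's HashingTF defaults to 2^18) to reduce hash collisions.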

2. Scala Version

  • EnglishSpamScala.scala
package top.it1002.spark.ml.EnglistSpam

import java.io.File
import java.util

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.{SparkConf, SparkContext}

/**
 * @Author       王磊
 * @Date         2018/12/17
 * @ClassName    EnglishSpamScala
 * @Description  English spam classification, Scala version
 **/

object EnglishSpamScala {
 def main(args: Array[String]): Unit = {
   val conf = new SparkConf().setMaster("local[5]").setAppName("EnglishSpam")
   val context = new SparkContext(conf)
   val sess = SparkSession.builder().config(conf).getOrCreate()
    // Read the dataset
   val lines = context.textFile("C:\\Users\\asus\\Desktop\\data\\spam_ham.txt")
    // Parse each line into a (label, content) row
   val rowRDD = lines.map{
     line =>
       val arr = line.split("\t")
       val types = if(arr(0) == "ham") 0.0 else 1.0
       Row(types, arr(1))
   }
   val fields = Array[StructField](
     new StructField("label", DataTypes.DoubleType,true),
     new StructField("content", DataTypes.StringType,true)
   )
   val schema = new StructType(fields)
   val data = sess.createDataFrame(rowRDD,schema)
    // Tokenizer: split the message text into words
   val words = new Tokenizer().setInputCol("content").setOutputCol("words")
    // HashingTF: hash words into 1000 buckets and count term frequencies as feature vectors
   val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features")
    // Logistic regression estimator, trained on the hashed feature vectors
   val logisticRegression = new LogisticRegression().setMaxIter(20).setRegParam(0.1)
    // Pipeline chaining the three stages together
   val pip = new Pipeline().setStages(Array(words, hashingTF, logisticRegression))
    // Fit the pipeline to the data to produce a model
   val model = pip.fit(data)
    // Persist the model; save() throws if the target path already exists,
    // so only save when it does not
    val modelPath = "C:\\Users\\asus\\Desktop\\data\\email\\english\\model\\scala"
    if(!new File(modelPath).exists()) model.save(modelPath)
    // Test with a few hand-made messages
   val testList = new util.ArrayList[String]()
   testList.add("Get me out of this dump heap. My mom decided to come to lowes. BORING.")
   testList.add("REMINDER FROM O2: To get 2.50 pounds free call credit and details of great offers pls reply 2 this text with your valid name, house no and postcode")
   testList.add("Cool, what time you think you can get here?")
   import sess.implicits._
    val testSet = sess.createDataset(testList).withColumnRenamed("value", "content")
   val predict = model.transform(testSet)
   predict.createTempView("res")
   sess.sql("select content,prediction from res").show()
 }
}
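Neither listing shows reading the persisted model back. Below is a minimal sketch (my addition, not part of the original post) of reloading the saved Scala model with PipelineModel.load and scoring new messages; the path matches the save path used above:

import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.SparkSession

object LoadSpamModelScala {
  def main(args: Array[String]): Unit = {
    val sess = SparkSession.builder().master("local[5]").appName("LoadSpamModel").getOrCreate()
    import sess.implicits._
    // Reload the whole pipeline: Tokenizer + HashingTF + LogisticRegressionModel
    val model = PipelineModel.load("C:\\Users\\asus\\Desktop\\data\\email\\english\\model\\scala")
    // Any DataFrame with a string "content" column can be scored directly,
    // since the loaded pipeline still contains the feature stages
    val testSet = Seq("Cool, what time you think you can get here?").toDF("content")
    model.transform(testSet).select("content", "prediction").show(false)
  }
}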

More small Spark MLlib examples will follow.
