文本分类是指将一篇文章归到事先定义好的某一类或者某几类,在数据平台的一个典型的应用场景是,通过爬取用户浏览过的页面内容,识别出用户的浏览偏好,从而丰富该用户的画像。
本文介绍使用Spark MLlib提供的朴素贝叶斯(Naive Bayes)算法,完成对中文文本的分类过程。主要包括中文分词、文本表示(TF-IDF)、模型训练、分类预测等。
特征工程
文本处理
对于中文文本分类,需要先对内容进行分词,我使用的是ansj中文分析工具,其中自己可以配置扩展词库来使分词结果更合理,同时可以加一些停用词可以提高准确率,需要把数据样本分割成两批数据,一份用于训练模型,一份用于测试模型效果。
代码
目录结构
DataFactory.java
package com.maweiming.spark.mllib.classifier;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.maweiming.spark.mllib.utils.AnsjUtils;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 1、first step
* data format
* Created by Coder-Ma on 2017/6/12.
*/
public class DataFactory {
public static final String CLASS_PATH = "/Users/coderma/coders/github/SparkTextClassifier/src/main/resources";
public static final String STOP_WORD_PATH = CLASS_PATH + "/data/stopWord.txt";
public static final String NEWS_DATA_PATH = CLASS_PATH + "/data/NewsData";
public static final String DATA_TRAIN_PATH = CLASS_PATH + "/data/data-train.txt";
public static final String DATA_TEST_PATH = CLASS_PATH + "/data/data-test.txt";
public static final String MODELS = CLASS_PATH + "/models";
public static final String MODEL_PATH = CLASS_PATH + "/models/category-4";
public static final String LABEL_PATH = CLASS_PATH + "/models/labels.txt";
public static final String TF_PATH = CLASS_PATH + "/models/tf";
public static final String IDF_PATH = CLASS_PATH + "/models/idf";
public static void main(String[] args) {
/**
* 收集数据、特征工程
* 1、遍历数据样本目录
* 2、对数据进行清洗,剔除掉停用词
*/
//数据样本切割比例 80%用于训练样本,20%数据用于测试模型准确率
Double spiltRate = 0.8;
//停用词
List<String> stopWords = FileUtils.readLine(line -> line, STOP_WORD_PATH);
//分类标签(标签id,分类名)
Map<Integer, String> labels = new HashMap<>();
Integer dirIndex = 0;
String[] dirNames = new File(NEWS_DATA_PATH).list();
for (String dirName : dirNames) {
dirIndex++;
labels.put(dirIndex, dirName);
String fileDirPath = String.format("%s/%s", NEWS_DATA_PATH, dirName);
String[] fileNames = new File(fileDirPath).list();
//当前分类目录的样本总数 * 切割比率
int spilt = Double.valueOf(fileNames.length * spiltRate).intValue();
for (int i = 0; i < fileNames.length; i++) {
String fileName = fileNames[i];
String filePath = String.format("%s/%s", fileDirPath, fileName);
System.out.println(filePath);
String text = FileUtils.readFile(filePath);
for (String stopWord : stopWords) {
text = text.replaceAll(stopWord, "");
}
if (StringUtils.isBlank(text)) {
continue;
}
//把文本内容进行分词
List<String> wordList = AnsjUtils.participle(text);
JSONObject data = new JSONObject();
data.put("text", wordList);
data.put("category", Double.valueOf(dirIndex));
if (i > spilt) {
//测试数据
FileUtils.appendText(DATA_TEST_PATH, data.toJSONString() + "\n");
} else {
//训练数据
FileUtils.appendText(DATA_TRAIN_PATH, data.toJSONString() + "\n");
}
}
}
FileUtils.writer(LABEL_PATH, JSON.toJSONString(labels));//data labels
System.out.println("Data processing successfully !");
System.out.println("=======================================================");
System.out.println("trainData:" + DATA_TRAIN_PATH);
System.out.println("testData:" + DATA_TEST_PATH);
System.out.println("labes:" + LABEL_PATH);
System.out.println("=======================================================");
}
}
训练模型
词语特征值处理(TF-IDF)
分好词后,每一个词都作为一个特征,需要将中文词语转换成Double型来表示,通常使用该词语的TF-IDF值作为特征值,Spark提供了全面的特征抽取及转换的API,非常方便,详见http://spark.apache.org/docs/latest/ml-features.html
为原始属于设置标签,按照resource->NewsData目录下面文件夹索引区分。
- car
- game
- it
- military
这里将中文词语转换成INT型的Hashing算法,类似于Bloomfilter,下面的setNumFeatures(500000)表示将Hash分桶的数量设置为500000个,这个值默认为2的20次方,即1048576,可以根据你的词语数量来调整,一般来说,这个值越大,不同的词被计算为一个Hash值的概率就越小,数据也更准确,但需要消耗更大的内存,和Bloomfilter是一个道理。
然后就可以训练模型,下面代码
代码
package com.maweiming.spark.mllib.classifier;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.File;
import java.io.IOException;
/**
* 2、The second step
* Created by Coder-Ma on 2017/6/26.
*/
public class NaiveBayesTrain {
public static void main(String[] args) throws IOException {
//1、创建一个SparkSession
SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local")
.getOrCreate();
//2、加载训练数据样本
Dataset<Row> train = spark.read().json(DataFactory.DATA_TRAIN_PATH);
//3、通过tf-idf计算数据样本中的词频
//word frequency count
HashingTF hashingTF = new HashingTF().setNumFeatures(500000).setInputCol("text").setOutputCol("rawFeatures");
Dataset<Row> featurizedData = hashingTF.transform(train);
//count tf-idf
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);
Dataset<Row> rescaledData = idfModel.transform(featurizedData);
//4、把数据样本转换成向量
JavaRDD<LabeledPoint> trainDataRdd = rescaledData.select("category", "features").javaRDD().map(v1 -> {
Double category = v1.getAs("category");
SparseVector features = v1.getAs("features");
Vector featuresVector = Vectors.dense(features.toArray());
return new LabeledPoint(Double.valueOf(category),featuresVector);
});
System.out.println("Start training...");
//调用朴素贝叶斯算法,传入向量数据训练模型
NaiveBayesModel model = NaiveBayes.train(trainDataRdd.rdd());
//save model
model.save(spark.sparkContext(), DataFactory.MODEL_PATH);
//save tf
hashingTF.save(DataFactory.TF_PATH);
//save idf
idfModel.save(DataFactory.IDF_PATH);
System.out.println("train successfully !");
System.out.println("=======================================================");
System.out.println("modelPath:"+DataFactory.MODEL_PATH);
System.out.println("tfPath:"+DataFactory.TF_PATH);
System.out.println("idfPath:"+DataFactory.IDF_PATH);
System.outprintln("=======================================================");
}
}
训练模型完成
train successfully !
=======================================================
modelPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/category-4
tfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/tf
idfPath:/Users/coderma/coders/github/SparkTextClassifier/src/main/resources/models/idf
=======================================================
测试模型
package com.maweiming.spark.mllib.classifier;
import com.alibaba.fastjson.JSON;
import com.maweiming.spark.mllib.dto.Result;
import com.maweiming.spark.mllib.utils.AnsjUtils;
import com.maweiming.spark.mllib.utils.FileUtils;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
import java.io.File;
import java.text.DecimalFormat;
import java.util.*;
/**
* 3、the third step
* Created by Coder-Ma on 2017/6/26.
*/
public class NaiveBayesTest {
private static HashingTF hashingTF;
private static IDFModel idfModel;
private static NaiveBayesModel model;
private static Map<Integer,String> labels = new HashMap<>();
public static void main(String[] args) {
SparkSession spark = SparkSession.builder().appName("NaiveBayes").master("local")
.getOrCreate();
//load tf file
hashingTF = HashingTF.load(DataFactory.TF_PATH);
//load idf file
idfModel = IDFModel.load(DataFactory.IDF_PATH);
//load model
model = NaiveBayesModel.load(spark.sparkContext(), DataFactory.MODEL_PATH);
//batch test
batchTestModel(spark, DataFactory.DATA_TEST_PATH);
//test a single
testModel(spark,"最近这段时间,由于印度三哥可能有些膨胀,在边境地区总想“搞事情”,这也让不少人的目光集中到此。事实上,我国在与印度的交界处有一军事要地,只要解放军一抬高水位,那么印军或就“不战而退”。它就是地处我国西藏与印度控制克什米尔交界的班公湖。\n" +
"\n" +
"\n" +
"众所周知,从古至今那些地处与军事险要易守难攻的形胜之地,都具有非常重要的军事意义。经常能左右一场战争的胜负。据悉,班公湖位于西藏自治区阿里地区日土县城西北。全长有600多公里,其中地处中国的有400多公里,地处与印度约有200公里。整体成东西走向,海拔在4000多米以上。湖水整体为淡水湖,但由于湖水在西段的淡水补给量的大方面建少,东西方向上交替不通畅,使西部的区域变成了咸水湖。于是便出现了一个有趣的现象,在东部的中国境内班公湖为淡水湖,在西部的印度境内,班公湖为咸水湖。\n" +
"\n" +
"\n" +
"而我军在于印度交界的班公湖区域有一个阀门,这个区域有着非常大的军事作用,而如果印军将部队部署在班公湖地区,我军只需打开阀门,抬高班公湖的东部水位。将他们的军事设施和军用要道给全部淹没。而印军的军事物资和后勤保障都将全部瘫痪,到时印度的军事部署都将全部不攻自破。\n" +
"\n" +
"\n" +
"而印度应该知道现代战争最为重要的便是后勤制度的保障,军事行动能否取得胜利,很大程度取决于后勤能否及时的跟上。而我军在班公湖地区地势上就有了绝对的军事优势,军用物资也可源源不断的运输上来,而印度却优势全无。而我国自古以来就是爱好和平的国家,人不犯我我不犯人。只希望印军能认清与我国军事力量的差距,不要盲目自信。\n" +
"\n");
}
public static void batchTestModel(SparkSession sparkSession, String testPath) {
Dataset<Row> test = sparkSession.read().json(testPath);
//word frequency count
Dataset<Row> featurizedData = hashingTF.transform(test);
//count tf-idf
Dataset<Row> rescaledData = idfModel.transform(featurizedData);
List<Row> rowList = rescaledData.select("category", "features").javaRDD().collect();
List<Result> dataResults = new ArrayList<>();
for (Row row : rowList) {
Double category = row.getAs("category");
SparseVector sparseVector = row.getAs("features");
Vector features = Vectors.dense(sparseVector.toArray());
double predict = model.predict(features);
dataResults.add(new Result(category, predict));
}
Integer successNum = 0;
Integer errorNum = 0;
for (Result result : dataResults) {
if (result.isCorrect()) {
successNum++;
} else {
errorNum++;
}
}
DecimalFormat df = new DecimalFormat("######0.0000");
Double result = (Double.valueOf(successNum) / Double.valueOf(dataResults.size())) * 100;
System.out.println("batch test");
System.out.println("=======================================================");
System.out.println("Summary");
System.out.println("-------------------------------------------------------");
System.out.println(String.format("Correctly Classified Instances : %s\t %s%%",successNum,df.format(result)));
System.out.println(String.format("Incorrectly Classified Instances : %s\t %s%%",errorNum,df.format(100-result)));
System.out.println(String.format("Total Classified Instances : %s",dataResults.size()));
System.out.println("===================================");
}
public static void testModel(SparkSession sparkSession, String content){
List<Row> data = Arrays.asList(
RowFactory.create(AnsjUtils.participle(content))
);
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, false), false, Metadata.empty())
});
Dataset<Row> testData = sparkSession.createDataFrame(data, schema);
//word frequency count
Dataset<Row> transform = hashingTF.transform(testData);
//count tf-idf
Dataset<Row> rescaledData = idfModel.transform(transform);
Row row =rescaledData.select("features").first();
SparseVector sparseVector = row.getAs("features");
Vector features = Vectors.dense(sparseVector.toArray());
Double predict = model.predict(features);
System.out.println("test a single");
System.out.println("=======================================================");
System.out.println("Result");
System.out.println("-------------------------------------------------------");
System.out.println(labels.get(predict.intValue()));
System.out.println("===================================");
}
}
测试结果
batch test
=======================================================
Summary
-------------------------------------------------------
Correctly Classified Instances : 785 98.6181%
Incorrectly Classified Instances : 11 1.3819%
Total Classified Instances : 796
===================================
准确率98%,还可以。以上就是文本分类器的实现,我们还可以直接把数据样本换成 正常邮件|垃圾邮件 这两类的数据,就可以实现一个垃圾邮箱分类器了
源码
https://github.com/Maweiming/SparkTextClassifier