Spark MLlib实现的中文文本分类–Naive Bayes

转载 2016年06月01日 15:43:43

文本分类是指将一篇文章归到事先定义好的某一类或者某几类,在数据平台的一个典型的应用场景是,通过爬取用户浏览过的页面内容,识别出用户的浏览偏好,从而丰富该用户的画像。
本文介绍使用Spark MLlib提供的朴素贝叶斯(Naive Bayes)算法,完成对中文文本的分类过程。主要包括中文分词、文本表示(TF-IDF)、模型训练、分类预测等。

中文分词

对于中文文本分类而言,需要先对文章进行分词,我使用的是IKAnalyzer中文分析工具,,其中自己可以配置扩展词库来使分词结果更合理,我从搜狗、百度输入法下载了细胞词库,将其作为扩展词库。这里不再介绍分词。

中文词语特征值转换(TF-IDF)

分好词后,每一个词都作为一个特征,但需要将中文词语转换成Double型来表示,通常使用该词语的TF-IDF值作为特征值,Spark提供了全面的特征抽取及转换的API,非常方便,详见http://spark.apache.org/docs/latest/ml-features.html,这里介绍下TF-IDF的API:

比如,训练语料/tmp/lxw1234/1.txt:

0,苹果 官网 苹果 宣布
1,苹果 梨 香蕉

逗号分隔的第一列为分类编号,0为科技,1为水果。

  1. case class RawDataRecord(category: String, text: String)
  2.  
  3. val conf = new SparkConf().setMaster("yarn-client")
  4. val sc = new SparkContext(conf)
  5. val sqlContext = new org.apache.spark.sql.SQLContext(sc)
  6. import sqlContext.implicits._
  7.  
  8. //将原始数据映射到DataFrame中,字段category为分类编号,字段text为分好的词,以空格分隔
  9. var srcDF = sc.textFile("/tmp/lxw1234/1.txt").map {
  10. x =>
  11. var data = x.split(",")
  12. RawDataRecord(data(0),data(1))
  13. }.toDF()
  14.  
  15. srcDF.select("category", "text").take(2).foreach(println)
  16. [0,苹果 官网 苹果 宣布]
  17. [1,苹果 香蕉]
  18. //将分好的词转换为数组
  19. var tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
  20. var wordsData = tokenizer.transform(srcDF)
  21.  
  22. wordsData.select($"category",$"text",$"words").take(2).foreach(println)
  23. [0,苹果 官网 苹果 宣布,WrappedArray(苹果, 官网, 苹果, 宣布)]
  24. [1,苹果 香蕉,WrappedArray(苹果, 梨, 香蕉)]
  25.  
  26. //将每个词转换成Int型,并计算其在文档中的词频(TF)
  27. var hashingTF =
  28. new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(100)
  29. var featurizedData = hashingTF.transform(wordsData)
  30.  

这里将中文词语转换成INT型的Hashing算法,类似于Bloomfilter,上面的setNumFeatures(100)表示将Hash分桶的数量设置为100个,这个值默认为2的20次方,即1048576,可以根据你的词语数量来调整,一般来说,这个值越大,不同的词被计算为一个Hash值的概率就越小,数据也更准确,但需要消耗更大的内存,和Bloomfilter是一个道理。

  1. featurizedData.select($"category", $"words", $"rawFeatures").take(2).foreach(println)
  2. [0,WrappedArray(苹果, 官网, 苹果, 宣布),(100,[23,81,96],[2.0,1.0,1.0])]
  3. [1,WrappedArray(苹果, 梨, 香蕉),(100,[23,72,92],[1.0,1.0,1.0])]

结果中,“苹果”用23来表示,第一个文档中,词频为2,第二个文档中词频为1.

  1. //计算TF-IDF值
  2. var idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
  3. var idfModel = idf.fit(featurizedData)
  4. var rescaledData = idfModel.transform(featurizedData)
  5. rescaledData.select($"category", $"words", $"features").take(2).foreach(println)
  6.  
  7. [0,WrappedArray(苹果, 官网, 苹果, 宣布),(100,[23,81,96],[0.0,0.4054651081081644,0.4054651081081644])]
  8. [1,WrappedArray(苹果, 梨, 香蕉),(100,[23,72,92],[0.0,0.4054651081081644,0.4054651081081644])]
  9.  
  10. //因为一共只有两个文档,且都出现了“苹果”,因此该词的TF-IDF值为0.
  11.  

最后一步,将上面的数据转换成Bayes算法需要的格式,如:

https://github.com/apache/spark/blob/branch-1.5/data/mllib/sample_naive_bayes_data.txt

  1. var trainDataRdd = rescaledData.select($"category",$"features").map {
  2. case Row(label: String, features: Vector) =>
  3. LabeledPoint(label.toDouble, Vectors.dense(features.toArray))
  4. }


每一个LabeledPoint中,特征数组的长度为100(setNumFeatures(100)),”官网”和”宣布”对应的特征索引号分别为81和96,因此,在特征数组中,第81位和第96位分别为它们的TF-IDF值。

到此,中文词语特征表示的工作已经完成,trainDataRdd已经可以作为Bayes算法的输入了。

分类模型训练

训练模型,语料非常重要,我这里使用的是搜狗提供的分类语料库,很早之前的了,这里只作为学习测试使用。

下载地址在:http://www.sogou.com/labs/dl/c.html,语料库一共有10个分类:

C000007 汽车
      C000008 财经
      C000010  IT
      C000013 健康
      C000014 体育
      C000016 旅游
      C000020 教育
      C000022 招聘
      C000023 文化
      C000024 军事

每个分类下有几千个文档,这里将这些语料进行分词,然后每一个分类生成一个文件,在该文件中,每一行数据表示一个文档的分词结果,重新用0-9作为这10个分类的编号:
0 汽车
1 财经
2 IT
3 健康
4 体育
5 旅游
6 教育
7 招聘
8 文化
9 军事

比如,汽车分类下的文件内容为:


数据准备好了,接下来进行模型训练及分类预测,代码:

  1. package com.lxw1234.textclassification
  2.  
  3. import scala.reflect.runtime.universe
  4.  
  5. import org.apache.spark.SparkConf
  6. import org.apache.spark.SparkContext
  7. import org.apache.spark.ml.feature.HashingTF
  8. import org.apache.spark.ml.feature.IDF
  9. import org.apache.spark.ml.feature.Tokenizer
  10. import org.apache.spark.mllib.classification.NaiveBayes
  11. import org.apache.spark.mllib.linalg.Vector
  12. import org.apache.spark.mllib.linalg.Vectors
  13. import org.apache.spark.mllib.regression.LabeledPoint
  14. import org.apache.spark.sql.Row
  15.  
  16.  
  17. object TestNaiveBayes {
  18. case class RawDataRecord(category: String, text: String)
  19. def main(args : Array[String]) {
  20. val conf = new SparkConf().setMaster("yarn-client")
  21. val sc = new SparkContext(conf)
  22. val sqlContext = new org.apache.spark.sql.SQLContext(sc)
  23. import sqlContext.implicits._
  24. var srcRDD = sc.textFile("/tmp/lxw1234/sougou/").map {
  25. x =>
  26. var data = x.split(",")
  27. RawDataRecord(data(0),data(1))
  28. }
  29. //70%作为训练数据,30%作为测试数据
  30. val splits = srcRDD.randomSplit(Array(0.7, 0.3))
  31. var trainingDF = splits(0).toDF()
  32. var testDF = splits(1).toDF()
  33. //将词语转换成数组
  34. var tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
  35. var wordsData = tokenizer.transform(trainingDF)
  36. println("output1:")
  37. wordsData.select($"category",$"text",$"words").take(1)
  38. //计算每个词在文档中的词频
  39. var hashingTF = new HashingTF().setNumFeatures(500000).setInputCol("words").setOutputCol("rawFeatures")
  40. var featurizedData = hashingTF.transform(wordsData)
  41. println("output2:")
  42. featurizedData.select($"category", $"words", $"rawFeatures").take(1)
  43. //计算每个词的TF-IDF
  44. var idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
  45. var idfModel = idf.fit(featurizedData)
  46. var rescaledData = idfModel.transform(featurizedData)
  47. println("output3:")
  48. rescaledData.select($"category", $"features").take(1)
  49. //转换成Bayes的输入格式
  50. var trainDataRdd = rescaledData.select($"category",$"features").map {
  51. case Row(label: String, features: Vector) =>
  52. LabeledPoint(label.toDouble, Vectors.dense(features.toArray))
  53. }
  54. println("output4:")
  55. trainDataRdd.take(1)
  56. //训练模型
  57. val model = NaiveBayes.train(trainDataRdd, lambda = 1.0, modelType = "multinomial")
  58. //测试数据集,做同样的特征表示及格式转换
  59. var testwordsData = tokenizer.transform(testDF)
  60. var testfeaturizedData = hashingTF.transform(testwordsData)
  61. var testrescaledData = idfModel.transform(testfeaturizedData)
  62. var testDataRdd = testrescaledData.select($"category",$"features").map {
  63. case Row(label: String, features: Vector) =>
  64. LabeledPoint(label.toDouble, Vectors.dense(features.toArray))
  65. }
  66. //对测试数据集使用训练模型进行分类预测
  67. val testpredictionAndLabel = testDataRdd.map(p => (model.predict(p.features), p.label))
  68. //统计分类准确率
  69. var testaccuracy = 1.0 * testpredictionAndLabel.filter(x => x._1 == x._2).count() / testDataRdd.count()
  70. println("output5:")
  71. println(testaccuracy)
  72. }
  73. }

执行后,主要输出如下:

output1:(将词语转换成数组)


output2:(计算每个词在文档中的词频)


output3:(计算每个词的TF-IDF)


output4:(Bayes算法的输入数据格式)


output5:(测试数据集分类准确率)


准确率90%,还可以。接下来需要收集分类更细,时间更新的数据来训练和测试了。。

更新:

程序中使用的/tmp/lxw1234/sougou/目录下的文件提供下载:

链接: http://pan.baidu.com/s/1ntUyI9N 密码: i5nj


转载请注明:lxw的大数据田地 » Spark MLlib实现的中文文本分类–Naive Bayes

举报

相关文章推荐

混杂设备驱动程序

混杂设备驱动程序是那些简单的字符驱动程序,它们拥有一些相同的特性。内核将这些共同行抽象至一个API中(具体实现代码见 drivers/char/misc),这些简化了驱动程序的初始化的方式。所有的混杂...

固态继电器原理及应用电路

固态继电器(SOLIDSTATE RELAYS),简写成“SSR”,是一种全部由固态电子元件组成的新型无触点开关器件,它利用电子元件(如开关三极管、双向可控硅等半导体器件)的开关特性,可达到无触点无火...

我是如何成为一名python大咖的?

人生苦短,都说必须python,那么我分享下我是如何从小白成为Python资深开发者的吧。2014年我大学刚毕业..

Spark RDD 到 LabelPoint的转换(包含构造临时数据的方法)

题目: 将数据的某个特征作为label, 其他特征(或其他某几个特征)作为Feature, 转为LabelPoint参考: http://www.it1352.com/220642.html 首先构...

Spark性能优化:资源调优篇

在开发完Spark作业之后,就该为作业配置合适的资源了。Spark的资源参数,基本都可以在spark-submit命令中作为参数设置。很多Spark初学者,通常不知道该设置哪些必要的参数,以及如何设置...

Spark性能优化:开发调优篇

前言 在大数据计算领域,Spark已经成为了越来越流行、越来越受欢迎的计算平台之一。Spark的功能涵盖了大数据领域的离线批处理、SQL类处理、流式/实时计算、机器学习、图计算等各种不同类型的计算操作...

Spark机器学习API之特征处理(一)

Spark机器学习库中包含了两种实现方式,一种是spark.mllib,这种是基础的API,基于RDDs之上构建,另一种是spark.ml,这种是higher-level API,基于DataFram...

Spark性能优化:shuffle调优

shuffle调优 调优概述       大多数Spark作业的性能主要就是消耗在了shuffle环节,因为该环节包含了大量的磁盘IO、序列化、网络数据传输等操作。因此,如果要让作业的性能更上一层...
返回顶部
收藏助手
不良信息举报
您举报文章:深度学习:神经网络中的前向传播和反向传播算法推导
举报原因:
原因补充:

(最多只允许输入30个字)