Extracting topics from Chinese documents with Spark's LDA algorithm

This post describes the approach to extracting document topics with the LDA algorithm in Spark ML; it does not cover the theory behind LDA. For the algorithm itself, see http://www.aboutyun.com/thread-20130-1-1.html

Using LDA to cluster Chinese text and extract its topics roughly involves the following steps:

1. Segment the Chinese text with a Chinese word segmenter; the open-source IK Analyzer is used here.

2. Remove stop words from the segmented word lists to produce new word lists.

3. Convert each document into a vector with a document-to-vector tool.

4. Run LDA on the vectors; once training finishes, retrieve the details of each topic and the distribution of topics over the documents (a Pipeline-based sketch of steps 2 to 4 follows this list).
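The code below wires these stages together by hand. Since StopWordsRemover, CountVectorizer, and LDA are all standard spark.ml pipeline stages, steps 2 to 4 could equally be chained with the Pipeline API. The following is only a minimal sketch (not how the post's code is structured), assuming a DataFrame named segmented that already has a words column of tokens:

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class LdaPipelineSketch {

    /** Chains stop-word removal, vectorization, and LDA as one spark.ml Pipeline. */
    public static PipelineModel fit(Dataset<Row> segmented, String[] stopWords) {
        StopWordsRemover remover = new StopWordsRemover()
                .setInputCol("words")
                .setOutputCol("filtered")
                .setStopWords(stopWords);
        CountVectorizer cv = new CountVectorizer()
                .setInputCol("filtered")
                .setOutputCol("features");
        LDA lda = new LDA().setK(5).setMaxIter(20);
        return new Pipeline()
                .setStages(new PipelineStage[] { remover, cv, lda })
                .fit(segmented);
    }
}

The fitted PipelineModel holds the LDAModel as its last stage, so topic inspection works the same way as in the hand-wired version below.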

The full code is as follows:

package com.googsoft.spark.ml.ik;

import java.io.StringReader;

import org.wltea.analyzer.core.IKSegmenter;
import org.wltea.analyzer.core.Lexeme;

public class IkAnalyzerTool {

    /**
     * Segments one line of Chinese text with IK Analyzer (smart mode)
     * and joins the tokens with single spaces.
     */
    public String call(String line) throws Exception {
        StringReader sr = new StringReader(line);
        IKSegmenter ik = new IKSegmenter(sr, true);
        Lexeme lex = null;
        StringBuffer sb = new StringBuffer();
        while ((lex = ik.next()) != null) {
            sb.append(lex.getLexemeText());
            sb.append(" ");
        }
        return sb.toString();
    }

    public static void main(String[] args) throws Exception {
        IkAnalyzerTool a = new IkAnalyzerTool();
        System.out.println(a.call("我是中国人"));
    }
}

The driver class below loads the documents, segments them, removes stop words, vectorizes them, and runs LDA:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.mutable.WrappedArray;

import com.googsoft.spark.ml.ik.IkAnalyzerTool;

public class MyCluster {

    public static void main(String[] args) {

        // Initialize Spark
        SparkSession spark = SparkSession
                .builder()
                .appName("mylda")
                .getOrCreate();

        // Load the raw documents: one (path, content) pair per file
        JavaRDD<Tuple2<String, String>> files = spark.sparkContext()
                .wholeTextFiles("hdfs://mycluster/ml/edata", 1).toJavaRDD();

        // Segment each document with IK Analyzer and keep (path, words) rows
        List<Row> rows = files.map(new Function<Tuple2<String, String>, Row>() {
            @Override
            public Row call(Tuple2<String, String> v1) throws Exception {
                IkAnalyzerTool it = new IkAnalyzerTool();
                return RowFactory.create(v1._1, Arrays.asList(it.call(v1._2).split(" ")));
            }
        }).collect();
        StructType schema = new StructType(new StructField[] {
                new StructField("fpath", DataTypes.StringType, false, Metadata.empty()),
                new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
        });
        Dataset<Row> documentDF = spark.createDataFrame(rows, schema);

        // Turn each file path into a numeric document id
        StringIndexer indexer = new StringIndexer()
                .setInputCol("fpath")
                .setOutputCol("docid");
        Dataset<Row> indexed = indexer.fit(documentDF).transform(documentDF);
        // Filter out stop words to improve topic quality
        List<String> stopWords = spark.read()
                .textFile("hdfs://mycluster/ml/stopwords/chinese.txt").collectAsList();
        StopWordsRemover remover = new StopWordsRemover()
                .setInputCol("words")
                .setOutputCol("filtered")
                .setStopWords(stopWords.toArray(new String[0]));
        Dataset<Row> filtered = remover.transform(indexed);

        // Convert the filtered word lists into term-count vectors with CountVectorizer
        CountVectorizer cv = new CountVectorizer().setInputCol("filtered").setOutputCol("features");
        CountVectorizerModel cvmodel = cv.fit(filtered);
        Dataset<Row> cvResult = cvmodel.transform(filtered);
        // Vocabulary built during vectorization; maps term indices back to words
        final String[] vocabulary = cvmodel.vocabulary();

        // Train LDA to extract the topics (5 topics, 20 iterations)
        LDA lda = new LDA().setK(5).setMaxIter(20);
        LDAModel ldaModel = lda.fit(cvResult);
        double ll = ldaModel.logLikelihood(cvResult);
        double lp = ldaModel.logPerplexity(cvResult);
        System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
        // Perplexity is the usual evaluation metric for LDA topic models; lower is better
        System.out.println("The upper bound on perplexity: " + lp);
        // For each topic, map the top term indices back to vocabulary words
        JavaRDD<Row> topics = ldaModel.describeTopics(30).toJavaRDD();
        List<Row> t1 = topics.map(new Function<Row, Row>() {
            @Override
            public Row call(Row row) throws Exception {
                int topic = row.getAs(0);
                WrappedArray<Integer> terms = row.getAs(1);
                List<String> termsed = new ArrayList<String>();
                Iterator<Integer> it = terms.iterator();
                while (it.hasNext()) {
                    int indice = it.next();
                    termsed.add(vocabulary[indice]);
                }
                WrappedArray<Double> termWeights = row.getAs(2);
                return RowFactory.create(topic, termsed.toArray(), termWeights);
            }
        }).collect();
        // DataFrame holding, for each topic, its terms (as words) and their weights
        StructType topicschema = new StructType(new StructField[] {
                new StructField("topic", DataTypes.IntegerType, false, Metadata.empty()),
                new StructField("terms", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()),
                new StructField("termWeights", DataTypes.createArrayType(DataTypes.DoubleType), false, Metadata.empty())
        });
        Dataset<Row> topicdatas = spark.createDataFrame(t1, topicschema);
        topicdatas.show(false);
        // Per-document topic distributions, written out as JSON
        Dataset<Row> transformed = ldaModel.transform(cvResult);
        Dataset<Row> finalset = transformed.select("docid", "topicDistribution");
        finalset.write().json("hdfs://mycluster/ml/result");
        spark.stop();
    }
}
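A possible follow-up that is not part of the original post: besides writing out the full topicDistribution vectors, the single most likely topic for each document can be read off with Vector.argmax(). A minimal sketch, assuming the transformed Dataset and docid column produced by the code above:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

public class DominantTopic {

    /** Returns (docid, index of the highest-probability topic) for every document. */
    public static JavaRDD<Row> dominantTopics(Dataset<Row> transformed) {
        return transformed.select("docid", "topicDistribution").toJavaRDD()
                .map(new Function<Row, Row>() {
                    @Override
                    public Row call(Row row) throws Exception {
                        Vector dist = row.getAs(1); // topic distribution of one document
                        return RowFactory.create(row.get(0), dist.argmax());
                    }
                });
    }
}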


