使用spark计算文档相似度

1、TF-IDF文档转换为向量

以下边三个句子为例

罗湖发布大梧桐新兴产业带整体规划
深化伙伴关系,增强发展动力
为世界经济发展贡献中国智慧

经过分词后变为

[罗湖, 发布, 大梧桐, 新兴产业, 带, 整体, 规划]|
[深化, 伙伴, 关系, 增强, 发展, 动力]
[为, 世界, 经济发展, 贡献, 中国, 智慧]

经过词频(TF)计算后,词频=某个词在文章中出现的次数

(262144,[10607,18037,52497,53469,105320,122761,220591],[1.0,1.0,1.0,1.0,1.0,1.0,1.0])
(262144,[8684,20809,154835,191088,208112,213540],[1.0,1.0,1.0,1.0,1.0,1.0]) 
(262144,[21159,30073,53529,60542,148594,197957],[1.0,1.0,1.0,1.0,1.0,1.0])  

262144为总词数,这个值越大,不同的词被计算为一个Hash值的概率就越小,数据也更准确。
[10607,18037,52497,53469,105320,122761,220591]分别代表罗湖, 发布, 大梧桐, 新兴产业, 带, 整体, 规划的向量值
[1.0,1.0,1.0,1.0,1.0,1.0,1.0]分别代表罗湖, 发布, 大梧桐, 新兴产业, 带, 整体, 规划在句子中出现的次数

经过逆文档频率(IDF),逆文档频率=log(总文章数/包含该词的文章数)

[6.062092444847088,7.766840537085513,7.073693356525568,5.201891179623976,7.073693356525568,5.3689452642871425,6.514077568590145]
[3.8750202389748862,5.464255444091467,6.062092444847088,7.3613754289773485,6.668228248417403,5.975081067857458]
[6.2627631403092385,4.822401557919072,6.2627631403092385,6.2627631403092385,3.547332831909406,4.065538562973019]

其中[6.062092444847088,7.766840537085513,7.073693356525568,5.201891179623976,7.073693356525568,5.3689452642871425,6.514077568590145]分别代表罗湖, 发布, 大梧桐, 新兴产业, 带, 整体, 规划的逆文档频率

2、相似度计算方法
在之前学习《Mahout实战》书中聚类算法中,知道几种相似性度量方法
欧氏距离测度
给定平面上的两个点,通过一个标尺来计算出它们之间的距离
penngo博客图片

平方欧氏距离测度
这种距离测度的值是欧氏距离的平方。

penngo博客图片

曼哈顿距离测度
两个点之间的距离是它们坐标差的绝对值

penngo博客图片

余弦距离测度
余弦距离测度需要我们将这些点视为人原点指向它们的向量,向量之间形成一个夹角,当夹角较小时,这些向量都会指向大致相同方向,因此这些点非常接近,当夹角非常小时,这个夹角的余弦接近于1,而随着角度变大,余弦值递减。
两个n维向量之间的余弦距离公式 

penngo博客图片

谷本距离测度
余弦距离测度忽略向量的长度,这适用于某些数据集,但是在其它情况下可能会导致糟糕的聚类结果,谷本距离表现点与点之间的夹角和相对距离信息。

penngo博客图片

加权距离测度
允许对不同的维度加权从而提高或减小某些维度对距离测度值的影响。

3、代码实现

spark ml有TF_IDF的算法实现,spark sql也能实现数据结果的轻松读取和排序,也自带有相关余弦值计算方法。本文将使用余弦相似度计算文档相似度,计算公式为

penngo
测试数据来源于12月07日-12月12日之间网上抓取,样本测试数据量为16632条,
数据格式为:Id@==@发布时间@==@标题@==@内容@==@来源。penngo_07_12.txt文件内容如下:

penngo博客图片

第一条新闻是这段时间的一个新闻热点,本文例子是计算所有新闻与第一条新闻的相似度,计算结果按相似度从高到低排序,最终结果保存在文本文件中。

使用maven创建项目spark项目

pom.xml配置

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.spark.penngo</groupId>
  <artifactId>spark_test</artifactId>
  <packaging>jar</packaging>
  <version>1.0-SNAPSHOT</version>
  <name>spark_test</name>
  <url>http://maven.apache.org</url>
  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>4.12</version>
      <scope>test</scope>
    </dependency>
	<dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>2.0.2</version>
    </dependency>
	<dependency>
	    <groupId>org.apache.spark</groupId>
		<artifactId>spark-sql_2.11</artifactId>
		<version>2.0.2</version>
	</dependency>
	<dependency>
		<groupId>org.apache.spark</groupId>
		<artifactId>spark-mllib_2.11</artifactId>
		<version>2.0.2</version>
	</dependency>
	<dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-client</artifactId>
        <version>2.2.0</version>
    </dependency>
	<dependency>
		<groupId>org.lionsoul</groupId>
		<artifactId>jcseg-core</artifactId>
		<version>2.0.0</version>
	</dependency>

      <dependency>
          <groupId>commons-io</groupId>
          <artifactId>commons-io</artifactId>
          <version>2.5</version>
      </dependency>
      <!--
      <dependency>
          <groupId>org.mongodb</groupId>
          <artifactId>mongodb-driver</artifactId>
          <version>3.3.0</version>
      </dependency>
      <dependency>
          <groupId>org.jsoup</groupId>
          <artifactId>jsoup</artifactId>
          <version>1.10.1</version>
      </dependency>
      <dependency>
          <groupId>com.alibaba</groupId>
          <artifactId>fastjson</artifactId>
          <version>1.2.21</version>
      </dependency>
      -->
  </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                    <encoding>UTF-8</encoding>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

SimilarityTest.java

package com.spark.penngo.tfidf;

import com.spark.test.tfidf.util.SimilartyData;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.linalg.BLAS;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.*;
import org.lionsoul.jcseg.tokenizer.core.*;

import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.StringReader;
import java.util.*;

/**
 * 计算文档相似度https://my.oschina.net/penngo/blog
 */
public class SimilarityTest {
    private static SparkSession spark = null;
    private static String splitTag = "@==@";
    public static Dataset<Row> tfidf(Dataset<Row> dataset) {
        Tokenizer tokenizer = new Tokenizer().setInputCol("segment").setOutputCol("words");
        Dataset<Row> wordsData = tokenizer.transform(dataset);
        HashingTF hashingTF = new HashingTF()
                .setInputCol("words")
                .setOutputCol("rawFeatures");
        Dataset<Row> featurizedData = hashingTF.transform(wordsData);
        IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
        IDFModel idfModel = idf.fit(featurizedData);
        Dataset<Row> rescaledData = idfModel.transform(featurizedData);
        return rescaledData;
    }

    public static Dataset<Row> readTxt(String dataPath) {
        JavaRDD<TfIdfData> newsInfoRDD = spark.read().textFile(dataPath).javaRDD().map(new Function<String, TfIdfData>() {
            private ISegment seg = null;
            private void initSegment() throws Exception {
                if (seg == null) {
                    JcsegTaskConfig config = new JcsegTaskConfig();
                    config.setLoadCJKPos(true);
                    String path = new File("").getAbsolutePath() + "/data/lexicon";
                    System.out.println(new File("").getAbsolutePath());
                    ADictionary dic = DictionaryFactory.createDefaultDictionary(config);
                    dic.loadDirectory(path);
                    seg = SegmentFactory.createJcseg(JcsegTaskConfig.COMPLEX_MODE, config, dic);
                }
            }

            public TfIdfData call(String line) throws Exception {
                initSegment();
                TfIdfData newsInfo = new TfIdfData();

                String[] lines = line.split(splitTag);
                if(lines.length < 5){
                    System.out.println("error==" + lines[0] + " " + lines[1]);
                }
                String id = lines[0];
                String publish_timestamp = lines[1];
                String title = lines[2];
                String content = lines[3];
                String source = lines.length >4 ? lines[4] : "" ;

                seg.reset(new StringReader(content));
                StringBuffer sff = new StringBuffer();
                IWord word = seg.next();
                while (word != null) {
                    sff.append(word.getValue()).append(" ");
                    word = seg.next();
                }
                newsInfo.setId(id);
                newsInfo.setTitle(title);
                newsInfo.setSegment(sff.toString());
                return newsInfo;
            }
        });
        Dataset<Row> dataset = spark.createDataFrame(
                newsInfoRDD,
                TfIdfData.class
        );
        return dataset;
    }
    public static SparkSession initSpark() {
        if (spark == null) {
            spark = SparkSession
                    .builder()
                    .appName("SimilarityPenngoTest").master("local[3]")
                    .getOrCreate();
        }
        return spark;
    }
    public static void similarDataset(String id, Dataset<Row> dataSet, String datePath) throws Exception{
        Row firstRow = dataSet.select("id", "title", "features").where("id ='" + id + "'").first();
        Vector firstFeatures = firstRow.getAs(2);

        Dataset<SimilartyData> similarDataset = dataSet.select("id", "title", "features").map(new MapFunction<Row, SimilartyData>(){
            public SimilartyData call(Row row) {
                String id = row.getString(0);
                String title = row.getString(1);
                Vector features = row.getAs(2);
                double dot = BLAS.dot(firstFeatures.toSparse(), features.toSparse());
                double v1 = Vectors.norm(firstFeatures.toSparse(), 2.0);
                double v2 = Vectors.norm(features.toSparse(), 2.0);
                double similarty = dot / (v1 * v2);
                SimilartyData similartyData = new SimilartyData();
                similartyData.setId(id);
                similartyData.setTitle(title);
                similartyData.setSimilarty(similarty);
                return similartyData;
            }
        }, Encoders.bean(SimilartyData.class));
        Dataset<Row> similarDataset2 = spark.createDataFrame(
                similarDataset.toJavaRDD(),
                SimilartyData.class
        );

        FileOutputStream out = new FileOutputStream(datePath);
        OutputStreamWriter osw = new OutputStreamWriter(out, "UTF-8");
        similarDataset2.select("id", "title", "similarty").sort(functions.desc("similarty")).collectAsList().forEach(row->{
            try{
                StringBuffer sff = new StringBuffer();
                String sid = row.getAs(0);
                String title = row.getAs(1);
                double similarty = row.getAs(2);
                sff.append(sid).append(" ").append(similarty).append(" ").append(title).append("\n");
                osw.write(sff.toString());
            }
            catch(Exception e){
                e.printStackTrace();
            }
        });
        osw.close();
        out.close();
    }
    public static void run() throws Exception{
        initSpark();
        String dataPath = new File("").getAbsolutePath() + "/data/penngo_07_12.txt";

        Dataset<Row> dataSet = readTxt(dataPath);
        dataSet.show();
        Dataset<Row> tfidfDataSet = tfidf(dataSet);
        String id = "58528946cc9434e17d8b4593";
        String similarFile = new File("").getAbsolutePath() + "/data/penngo_07_12_similar.txt";
        similarDataset(id, tfidfDataSet, similarFile);

    }

    public static void main(String[] args) throws Exception{
        //window上运行
        //System.setProperty("hadoop.home.dir", "D:/penngo/hadoop-2.6.4");
        //System.setProperty("HADOOP_USER_NAME", "root");
        run();
    }

}

运行结果,相似度越高的,新闻排在越前边,样例数据的测试结果基本满足要求。data_07_12_similar.txt文件内容如下:

penngo博客图片

参考资料

《Mahout实战》

TF-IDF与余弦相似性的应用(二):找出相似文章

转载于:https://my.oschina.net/penngo/blog/807810

  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Spark计算文本相似度可以采用以下步骤: 1. 加载文本数据:使用Spark的DataFrame API加载文本数据,将每个文本转换为一个行记录。 2. 分词和特征提取:将每个文本进行分词并提取特征,这些特征可以是词频、TF-IDF等。 3. 计算相似度使用Spark的MLlib库中的相似度计算算法,如余弦相似度或欧几里得距离等,计算每对文本之间的相似度。 4. 结果展示:将相似度结果保存到DataFrame中,并进行展示和分析下面是一个简单的示例代码: ```python from pyspark.sql.functions import udf from pyspark.ml.feature import Tokenizer, HashingTF from pyspark.ml.feature import Normalizer from pyspark.ml.linalg import Vectors from pyspark.ml.feature import VectorAssembler from pyspark.ml.feature import BucketedRandomProjectionLSH from pyspark.sql.functions import col from pyspark.sql.types import IntegerType # 加载文本数据 df = spark.read.text("path/to/text/file.txt") # 分词和特征提取 tokenizer = Tokenizer(inputCol="value", outputCol="words") wordsData = tokenizer.transform(df) hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=10000) featurizedData = hashingTF.transform(wordsData) idf = IDF(inputCol="rawFeatures", outputCol="features") idfModel = idf.fit(featurizedData) rescaledData = idfModel.transform(featurizedData) # 计算相似度 normalizer = Normalizer(inputCol="features", outputCol="normFeatures") data = normalizer.transform(rescaledData) vectorAssembler = VectorAssembler(inputCols=["normFeatures"], outputCol="featuresVec") data = vectorAssembler.transform(data) brp = BucketedRandomProjectionLSH(inputCol="featuresVec", outputCol="hashes", bucketLength=0.1, numHashTables=20) model = brp.fit(data) similar = model.approxSimilarityJoin(data, data, 0.6) # 结果展示 similar = similar.filter(col("datasetA.id") < col("datasetB.id")) similar = similar.withColumn("id1", similar["datasetA.id"].cast(IntegerType())) similar = similar.withColumn("id2", similar["datasetB.id"].cast(IntegerType())) similar = similar.select("id1", "id2", "distCol") similar.show() ``` 在这个示例中,我们使用了哈希特征提取(HashingTF)和逆文档频率(IDF)转换来进行特征提取,然后使用了归一化器(Normalizer)对特征向量进行标准化。最后,我们使用了桶随机投影局部敏感哈希(BucketedRandomProjectionLSH)算法计算文本之间的相似度

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值