TF_TDF文本比较相似度算法

TF_TDF文本比较相似度算法

代码

package basic.util;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.wltea.analyzer.lucene.IKAnalyzer;

import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @Description:
 * @Author: ada
 * @Date: 2020/9/27 23:27
 * @Vervion: 1.0
 */
public class Tf_tdfUtil2 {
	private static final String TOTAL_COUNT1 = "totalCount1";
    private static final String TOTAL_COUNT2 = "totalCount2";
    private static final String FILE_COUNT = "fileCount";
    private static final String TF1 = "tf1";
    private static final String TF2 = "tf2";
    private static final String TDF1 = "tdf1";
    private static final String TDF2 = "tdf2";
    /*
     * 避免除数为0.
     */
    private static final double TDF_ASSIST = 1d;
    /*
     * @Desciption:  
     * 		总词数、每个单词出现数量n、每个词的频率tf,每个词出现出现在不同文件的次数d、文件总数1+D、tdf=log(D/d)、tf_tdf=tf*tdf.
     *      每个文件的信息指纹hash(判断是否被读取).
     *      每个文件单词表:单词,数量,出现文件数,tf、tdf、tf_tdf.
     *      文件表:文件、tf_tdf、分类、单词总数.
     * 	    这里有个前置条件,字数越少,形成的总词库就越小,会影响使用效果.
     * @Return: void
     * @Author: ada
     * @Date: 2020/9/27 23:39
     * @Version: 1.0
     */
    public static void main(String[] args) {
        try {
            String str1 = TextUtil.getFileContentByPath("E:\\myself\\tf_tdf\\1.txt");
            String str2 = TextUtil.getFileContentByPath("E:\\myself\\tf_tdf\\2.txt");
            System.out.println("***************************************");
            System.out.println("内容1:" + str1);
            System.out.println("内容2:" + str1);
            System.out.println("内容1与内容2的相似度为(1:完全相似,0:完全不相似)为:" + getTfTdfValue(str1, str1));
            System.out.println("***************************************");
            System.out.println("***************************************");
            System.out.println("内容1:" + str1);
            System.out.println("内容2:" + str2);
            System.out.println("内容1与内容2的相似度为(1:完全相似,0:完全不相似)为:" + getTfTdfValue(str1, str2));
            System.out.println("***************************************");

        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /*
     * @Desciption: 得到两个文本的余弦值.
     * @param str1   文本1.
     * @param str2   文本2.
     * @Return: double   余弦值.
     * @Author: ada
     * @Date: 2021/3/12 22:43
     * @Version: 1.0
     */
    public static double getTfTdfValue(String str1, String str2) throws Exception {
        // 获取str1的单词表
        List<String> charList1 = cutChar(str1);
        // 获取str2的单词表
        List<String> charList2 = cutChar(str2);
        // 获取str1的单词向量表
        Map<String, Map<String, Double>> charVetorTable = getCharVetorTable(charList1, charList2);
        return getCos(charVetorTable);
    }


    /*
     * @Desciption: 得到字符串str1的单词向量表1.
     * @param charList1     单词表1.
     * @param charList2     单词表2.
     * @Return: java.util.Map<java.lang.String,java.lang.String>
     * @Author: ada
     * @Date: 2021/3/12 22:46
     * @Version: 1.0
     */
    public static Map<String, Map<String, Double>> getCharVetorTable(List<String> charList1, List<String> charList2) {
        Map<String, Map<String, Double>> vectorMap = new HashMap<>();
        charList1.forEach(chars -> {
            if (vectorMap.containsKey(chars)) {
                Map<String, Double> vector = vectorMap.get(chars);
                vector.put(TOTAL_COUNT1, vector.get(TOTAL_COUNT1) + 1);
                vector.put(FILE_COUNT, vector.get(FILE_COUNT) + 1);
            } else {
                Map<String, Double> vector = new HashMap<>();
                vector.put(TOTAL_COUNT1, 1d);
                vector.put(TOTAL_COUNT2, 0d);
                vector.put(FILE_COUNT, 1d);
                vectorMap.put(chars, vector);
            }
        });
        charList2.forEach(chars -> {
            if (vectorMap.containsKey(chars)) {
                Map<String, Double> vector = vectorMap.get(chars);
                vector.put(TOTAL_COUNT2, vector.get(TOTAL_COUNT2) + 1);
                vector.put(FILE_COUNT, vector.get(FILE_COUNT) + 1);
            } else {
                Map<String, Double> vector = new HashMap<>();
                vector.put(TOTAL_COUNT1, 0d);
                vector.put(FILE_COUNT, 1d);
                vector.put(TOTAL_COUNT2, 1d);
                vectorMap.put(chars, vector);
            }
        });
        // 两个字符串的总词数
        double totalCharCount = charList1.size() + charList2.size();
        // 得到词频(tf)和逆文档频率(tdf)
        vectorMap.forEach((key, map) -> {
            double totalCount1 = map.get(TOTAL_COUNT1);
            double totalCount2 = map.get(TOTAL_COUNT2);
            double fileCount = map.get(FILE_COUNT);
            double tf1 = totalCount1 / totalCharCount;
            double tdf1 = Math.log(2 / (fileCount + TDF_ASSIST));
            double tf2 = totalCount2 / totalCharCount;
            double tdf2 = Math.log(2 / (fileCount + TDF_ASSIST));
            map.put(TF1, tf1);
            map.put(TDF1, tdf1);
            map.put(TF2, tf2);
            map.put(TDF2, tdf2);
        });
        return vectorMap;
    }

    /*
     * @Desciption: 得到两个文本的余弦值.
     * @param str1   文本1.
     * @param str2   文本2.
     * @Return: double   余弦值.
     * @Author: ada
     * @Date: 2021/3/12 22:43
     * @Version: 1.0
     */
    public static BigDecimal getCos(Map<String, Map<String, Double>> charVectorTable) {
        BigDecimal sum1 = BigDecimal.ZERO;
        BigDecimal sum2 = BigDecimal.ZERO;
        for (Map.Entry<String, Map<String, Double>> entry : charVectorTable.entrySet()) {
            Map<String, Double> mapValue = entry.getValue();
            sum1 = sum1.add(BigDecimal.valueOf(mapValue.get(TF1)).multiply(BigDecimal.valueOf(mapValue.get(TDF1))));
            sum2 = sum2.add(BigDecimal.valueOf(mapValue.get(TF2)).multiply(BigDecimal.valueOf(mapValue.get(TDF2))));
        }
        BigDecimal result1 = sum2.compareTo(BigDecimal.ZERO) == 0 ? BigDecimal.ZERO : sum1.divide(sum2, 6);
        BigDecimal result2 = sum1.compareTo(BigDecimal.ZERO) == 0 ? BigDecimal.ZERO : sum2.divide(sum1, 6);
        return result1.compareTo(result2) < 0 ? result1 : result2;
    }

    public static List<String> cutChar(String str) throws Exception {
        if(StringUtils.isEmpty(str)){
            throw new Exception("字符串为空");
        }
        List<String> stringList=new ArrayList<>();
        // 创建分词对象
        Analyzer analyzer = new IKAnalyzer(true);
        StringReader reader = new StringReader(str);

        // 分词
        TokenStream ts = analyzer.tokenStream("", reader);
        ts.reset();
        CharTermAttribute term = ts.getAttribute(CharTermAttribute.class);

        // 遍历分词数据
        while(ts.incrementToken()){
            stringList.add(term.toString());
        }
        reader.close();
        return stringList;
    }

}


运行图

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值