BM25 Search Relevance Scoring Algorithm

package com.btg.core.util.bm25;

import org.wltea.analyzer.core.IKSegmenter;
import org.wltea.analyzer.core.Lexeme;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * References:
 * https://www.jianshu.com/p/1e498888f505
 * https://www.cnblogs.com/jiangxinyang/p/10516302.html
 * https://www.zybuluo.com/evilking/note/902621
 * https://github.com/hankcs/HanLP/blob/master/src/main/java/com/hankcs/hanlp/summary/BM25.java
 * https://github.com/haifengl/smile/blob/master/nlp/src/main/java/smile/nlp/relevance/BM25.java
 * https://github.com/jllan/jannlp
 *
 * BM25 parameter tuning notes: https://www.cnblogs.com/NaughtyBaby/p/9774836.html
 *
 * BM25 search relevance scoring algorithm:
 * computes the similarity between a query and each document in a collection.
 */
public class BM25Test {

    // BM25 tuning parameters: k1 controls term-frequency saturation, b controls length normalization
    private static final double k1 = 1.5;
    private static final double b = 0.75;
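    /*
     * The score computed in main() for a query q = {q1, ..., qn} against a document d is
     *
     *   score(d, q) = sum_i IDF(qi) * fi * (k1 + 1) / (fi + K)
     *   K           = k1 * (1 - b + b * dl / avgdl)
     *   IDF(qi)     = log10(1 + (N - n(qi) + 0.5) / (n(qi) + 0.5))
     *
     * where fi is the frequency of qi in d, n(qi) is the number of documents containing qi,
     * N is the number of documents, dl is the length of d, and avgdl is the average document
     * length. Note that this test measures dl and avgdl in characters rather than in tokens.
     */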

    /**
     * The document collection.
     * @return the list of documents (sentences) to score against
     */
    private static List<String> sentences() {
        List<String> sentences = new ArrayList<>();
        sentences.add("Elasticsearch is a highly scalable open-source full-text search and analytics engine");
        sentences.add("It allows you to store, search, and analyze big volumes of data quickly and in near real time");
        sentences.add("is generally used as the underlying engine/technology that powers applications that have complex search features and requirements");
        sentences.add("You run an online web store where you allow your customers to search for products that you sell");
        sentences.add("You want to collect log or transaction data and you want to analyze and mine this data to look for trends, statistics, summarizations, or anomalies");
        return sentences;
    }

    /**
     * The query string.
     * @return the query whose relevance to each document is scored
     */
    private static String query() {
        return "want";
    }

    public static void main(String[] args) {
        List<String> sentences = sentences();
        // number of documents in the collection
        final int N = sentences.size();
        // average document length (in characters, see avgdl())
        final double avgdl = avgdl(sentences);
        System.out.println("avgdl = " + avgdl);

        String query = query();
        System.out.println("查询语句 = " + query);

        List<String> queryWords = ikanalyzer(query);
        System.out.println("查询语句分词结果 = " + queryWords);

        System.out.println("--------------------------------------------------------");

        System.out.println("文档集合 = " + sentences);
        List<List<String>> allWords = new ArrayList<>();
        // 文本集合分词
        for(int i = 0, len = sentences.size(); i < len; i++) {
            allWords.add(i, ikanalyzer(sentences.get(i)));
        }
        System.out.println("文档集合分词结果 = " + allWords);
        System.out.println("--------------------------------------------------------");

        List<Map<String, Integer>> fs = f(allWords);
        System.out.println("文档中每个句子中的每个词与词频 = " + fs);
        System.out.println("--------------------------------------------------------");

        Map<String, Integer> nqis = nqi(allWords, queryWords);
        System.out.println("包含查询词的文本数量 = " + nqis);
        System.out.println("--------------------------------------------------------");

        Map<String, Double> idfs = idf(N, queryWords, nqis);
        System.out.println("统计每个查询词在文本集合中的权重 = " + idfs);
        System.out.println("--------------------------------------------------------");

        // score each document in the collection against the query
        for(int i = 0, len = sentences.size(); i < len; i++) {
            String sentence = sentences.get(i);
            double score = 0;
            for(String qw : queryWords) {
                // wi: IDF weight of the query term qw
                Double wi = idfs.get(qw);
                // fi: frequency of qw in document i
                int fi = fs.get(i).getOrDefault(qw, 0);
                double R = fi * (k1 + 1) / (fi + K(sentence, avgdl));
                score += wi * R;
            }
            System.out.println("document = " + sentence + ", BM25 score = " + score);
        }
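        // Expected outcome for the query "want": only the last document contains the
        // term, so it is the only one that should receive a non-zero BM25 score.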
    }



    /**
     * Tokenizes a line of text with the IK Analyzer (smart mode).
     * @return the list of tokens
     */
    private static List<String> ikanalyzer(String line) {
        List<String> words = new ArrayList<>();
        IKSegmenter ik = new IKSegmenter(new StringReader(line), true);
        try {
            Lexeme lex;
            while((lex = ik.next()) != null) {
                words.add(lex.getLexemeText());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return words;
    }

    /**
     * Counts, for each query term qi, the number of documents that contain it, n(qi).
     * @return map from query term to the number of documents containing it
     */
    private static Map<String, Integer> nqi(List<List<String>> allWords, List<String> queryWords) {
        Map<String, Integer> nqis = new TreeMap<>();
        for(String qw : queryWords) {
            for(List<String> aws : allWords) {
                if(aws.contains(qw)) {
                    nqis.put(qw, nqis.getOrDefault(qw, 0) + 1);
                }
            }
        }
        return nqis;
    }

    /**
     * Computes the IDF weight w(i) of each query term q(i),
     * i.e. how strongly the term discriminates between documents in the collection.
     * @return map from query term to its IDF weight
     */
    private static Map<String, Double> idf(int N, List<String> queryWords, Map<String, Integer> nqis) {
        Map<String, Double> idfs = new HashMap<>();
        for(String qw : queryWords) {
            int n = nqis.getOrDefault(qw, 0);
            double temp = (N - n + 0.5) / (n + 0.5);
            double idf = Math.log10(1 + temp);
            idfs.put(qw, idf);
        }
        return idfs;
    }
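    /*
     * Worked example for the IDF above: with N = 5 documents and the query term
     * "want" occurring in n = 1 of them, temp = (5 - 1 + 0.5) / (1 + 0.5) = 3 and
     * idf = log10(1 + 3) ≈ 0.602. Adding 1 inside the logarithm (as Lucene's BM25
     * variant does, with a natural log) keeps the weight non-negative even for
     * terms that appear in more than half of the documents.
     */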

    /**
     * Term frequencies: for each document, the count of every term it contains.
     * @param allWords the tokenized documents
     * @return one term-frequency map per document
     */
    private static List<Map<String, Integer>> f(List<List<String>> allWords) {
        List<Map<String, Integer>> fs = new ArrayList<>();
        for(List<String> aw : allWords) {
            Map<String, Integer> map = new TreeMap<>();
            for(String w : aw) {
                map.put(w, map.getOrDefault(w, 0) + 1);
            }
            fs.add(map);
        }
        return fs;
    }

    /**
     * Average length of all documents in the collection D.
     * Length is measured in characters here; standard BM25 uses token counts.
     * @param sentences the document collection
     * @return the average document length
     */
    private static double avgdl(List<String> sentences) {
        double totalLen = 0.0;
        for(String sentence : sentences) {
            totalLen += sentence.length();
        }
        return totalLen / sentences.size();
    }

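    /**
     * Length-normalization factor K = k1 * (1 - b + b * dl / avgdl), where dl is the
     * length of the current document (in characters, matching avgdl()). With b = 0
     * document length is ignored; with b = 1 the term-frequency component is fully
     * normalized by dl / avgdl.
     */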
    private static double K(String sentence, double avgdl) {
        int dl = sentence.length();
        return k1 * (1 - b + b * dl / avgdl);
    }

}
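
The test above keeps all of the bookkeeping inside main(). For reuse, the same logic can be packaged as a small scorer class. The following is a minimal sketch under the same IK Analyzer dependency; the class and method names (Bm25Scorer, score, tokenize) are illustrative and not part of the original post, and unlike the test above it measures document length in tokens, which is the standard BM25 formulation.

package com.btg.core.util.bm25;

import org.wltea.analyzer.core.IKSegmenter;
import org.wltea.analyzer.core.Lexeme;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Bm25Scorer {

    private static final double K1 = 1.5;
    private static final double B = 0.75;

    // per-document term frequencies and token lengths, built once in the constructor
    private final List<Map<String, Integer>> termFreqs = new ArrayList<>();
    private final List<Integer> docLengths = new ArrayList<>();
    private final double avgdl;

    public Bm25Scorer(List<String> documents) {
        int totalLen = 0;
        for (String doc : documents) {
            List<String> tokens = tokenize(doc);
            Map<String, Integer> tf = new HashMap<>();
            for (String t : tokens) {
                tf.merge(t, 1, Integer::sum);
            }
            termFreqs.add(tf);
            docLengths.add(tokens.size());
            totalLen += tokens.size();
        }
        avgdl = documents.isEmpty() ? 0 : (double) totalLen / documents.size();
    }

    /** BM25 score of the query against document index i. */
    public double score(String query, int i) {
        int n = termFreqs.size();
        double result = 0;
        for (String term : tokenize(query)) {
            // document frequency of the term across the collection
            int df = 0;
            for (Map<String, Integer> tf : termFreqs) {
                if (tf.containsKey(term)) {
                    df++;
                }
            }
            double idf = Math.log10(1 + (n - df + 0.5) / (df + 0.5));
            int fi = termFreqs.get(i).getOrDefault(term, 0);
            double k = K1 * (1 - B + B * docLengths.get(i) / avgdl);
            result += idf * fi * (K1 + 1) / (fi + k);
        }
        return result;
    }

    private static List<String> tokenize(String text) {
        List<String> words = new ArrayList<>();
        IKSegmenter ik = new IKSegmenter(new StringReader(text), true);
        try {
            Lexeme lex;
            while ((lex = ik.next()) != null) {
                words.add(lex.getLexemeText());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return words;
    }
}

Usage mirrors the test: new Bm25Scorer(sentences()).score(query(), i) returns the BM25 score of the query against document i.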

 
