BM25实现句子匹配

1、仅记录BM25实现,并做封装

import numpy as np
from collections import Counter
import jieba
import re
# 封装   导入包
from flask import request, Flask
from flask_cors import CORS
import json

# flask格式
app = Flask(__name__)
CORS(app, supports_credentials=True)
# 解决乱码问题
app.config['JSON_AS_ASCII']=False
# 满足get/pos请求
@app.route("/similarity", methods=["GET", "POST"]) 



def new_flask():
    '''
    documents_list 表示需要输入的文本列表,内部每个文本需要事先分好词
    documents_number表示文本总个数
    avg_documents_len 表示所有文本的平均长度
    f 用于存储每个文本中每个词的出现的次数
    idf 用于存储每个词汇的权重值
    init 函数是类初始化函数, 用于求解文本集合中的f和idf变量
    get_score 函数是获取一个文本与文本列表中一个文本的bm25相似度值
    get_documents_score 函数是获取一个文本与文本列表中所有文本的bm25相似度值
    '''
    
    class BM25_Model(object):
        def __init__(self, documents_list, k1=2, k2=1, b=0.5):
            self.documents_list = documents_list
            self.documents_number = len(documents_list)
            self.avg_documents_len = sum([len(document) for document in documents_list]) / self.documents_number
            self.f = []
            self.idf = {}
            self.k1 = k1
            self.k2 = k2
            self.b = b
            self.init()

        def init(self):
            df = {}
            for document in self.documents_list:
                temp = {}
                for word in document:
                    temp[word] = temp.get(word, 0) + 1
                self.f.append(temp)
                for key in temp.keys():
                    df[key] = df.get(key, 0) + 1
            for key, value in df.items():
                self.idf[key] = np.log((self.documents_number - value + 0.5) / (value + 0.5))

        def get_score(self, index, query):
            score = 0.0
            document_len = len(self.f[index])
            qf = Counter(query)
            for q in query:
                if q not in self.f[index]:
                    continue
                score += self.idf[q] * (self.f[index][q] * (self.k1 + 1) / (
                            self.f[index][q] + self.k1 * (1 - self.b + self.b * document_len / self.avg_documents_len))) * (
                                    qf[q] * (self.k2 + 1) / (qf[q] + self.k2))

            return score

        def get_documents_score(self, query):
            score_list = []
            for i in range(self.documents_number):
                score_list.append(self.get_score(i, query))
            return score_list


    # 获取停用词列表
    def get_stopwords_list():
        stopwords = [line.strip() for line in open('data/stopwords.txt', "r", encoding='utf-8').readlines()]
        return stopwords

    # 对用户问题进行分词操作
    def seg_depart(sentence):
        # 对文档中的每一行进行中文分词
        sentence_depart = jieba.lcut(sentence.strip())
        return sentence_depart

    def remove_digits(input_str):
        punc = u'0123456789.'
        output_str = re.sub(r'[{}]+'.format(punc), '', input_str)
        return output_str


    # 标准问题库去除停用词
    def move_stopwords_ku(sentence_list, stopwords_list):
        # 去停用词
        out_list = []
        new_list = []
        for word in sentence_list:
            # print(word)     # 此时是一个list
            for new_word in word:
                # print(new_word)
                if new_word not in stopwords_list:
                    if not remove_digits(new_word):
                        continue
                    if new_word != '\t':
                        new_list.append(new_word)
            out_list.append(new_list)
            # append一次之后将此list置空,否则会重复添加之前遍历的内容
            new_list = []
        return out_list


    # 用户问题去除停用词
    def move_stopwords(sentence_list, stopwords_list):
        # 去停用词
        out_list = []
        for word in sentence_list:
            if word not in stopwords_list:
                if not remove_digits(word):
                    continue
                if word != '\t':
                    out_list.append(word)
        return out_list



    # 获取停用此表
    stopwords = get_stopwords_list()

    # 获取问题库数据
    def get_sentence_list():
        # ATEC  /  BQ   /   LCQMC
        sentence_list = [line.strip() for line in open('data/all_data.txt', "r", encoding='utf-8').readlines()]
        # print(stopwords[:5])
        return sentence_list

    # 标准问库
    document_list = ["行政机关强行解除行政协议造成损失,如何索取赔偿?",
                    "借钱给朋友到期不还得什么时候可以起诉?怎么起诉?",
                    "我在微信上被骗了,请问被骗多少钱才可以立案?",
                    "公民对于选举委员会对选民的资格申诉的处理决定不服,能不能去法院起诉吗?",
                    "有人走私两万元,怎么处置他?",
                    "法律上餐具、饮具集中消毒服务单位的责任是不是对消毒餐具、饮具进行检验?"]


    # 对标准问库document_list进行中文分词
    document_list = [list(jieba.cut(doc)) for doc in document_list]

    # 对库sentence_list进行分词
    sentence_list = get_sentence_list()
    document_list = [list(jieba.cut(doc)) for doc in sentence_list]
    # print(sentence_list[:5])

    # 标准问题库去除停用词
    document_list = move_stopwords_ku(document_list, stopwords)

    # 实例化BM25类,生成一个对象,默认k1​, k2 ​和 b使用默认值
    bm25_model = BM25_Model(document_list)

    '''
    # 参数调用,观察示例化的对象中documents_list ,documents_number,avg_documents_len ,f 和idf变量具体存储了什么
    # print("\nbm25_model.documents_list")
    # print(bm25_model.documents_list)
    # print("\n*bm25_model.documents_number")
    # print(bm25_model.documents_number)
    # print("\n*bm25_model.avg_documents_len")
    # print(bm25_model.avg_documents_len)
    # print("\n*bm25_model.f")
    # print(bm25_model.f)
    # print("\n*bm25_model.idf")
    # print(bm25_model.idf)
    '''

    # 用户问题
    # query_zero = "可以使用蚂蚁花呗充值加油卡吗"
    # flask请求数据
    def form_or_json():
            if request.get_json(silent=True):
                return request.get_json(silent=True)
            else:
                if request.form:
                    return request.form
                else:
                    return request.args
    data = form_or_json()
    query_zero =data['text']
    print(query_zero)


    # 对用户问题分词
    query = list(jieba.cut(query_zero))

    # 对用户问题去除停用词
    query = move_stopwords(query, stopwords)

    # 输出相似度(对应相似标准库中的值,即标准库中多少数据就有多少个数据输出,最大的为最相似!)
    scores = bm25_model.get_documents_score(query)

    # 输出结果
    # print(f"document_scores:{scores}")
    # 分词结果
    print(f"The most similarity to {query} is {document_list[np.argmax(scores)]}")
    # 未分词结果
    print(f"The most similarity to {query_zero} is {sentence_list[np.argmax(scores)]}")
    # 格式list,不然请求失败
    answer_data = [sentence_list[np.argmax(scores)]]
    # print(answer_data)
    return(answer_data)

if __name__ == '__main__' :
    app.run("0.0.0.0", "5002", debug=True)
    

2、测试请求

# 传递数据的请求
import json
import requests

REQUEST_URL = "http://127.0.0.1:5002/similarity"
HEADER = {'Content-Type':'application/json; charset=utf-8'}

# requestDict = {}
# requestDict["text"] = input("请输入文本:")
requestDict = {"text": "可以使用蚂蚁花呗充值加油卡吗"}

rsp = requests.post(REQUEST_URL, json.dumps(requestDict), headers=HEADER)
rspJson = json.loads(rsp.text.encode())
print(rspJson)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值