金融LLM问答系统record

项目背景

随着人工智能和数字化技术在经济领域的广泛应用,经济领域对于高质量的检索和问答需求日益增长,而大模型的能力在这方面具有巨大的潜力。但普通的大模型无法回答很多商业问题,因为大模型并没有相关数据收集, 已进行测试大部分商业问题无法回答。

总体思路

  1. 收集股票相关数据
  2. 收集股民问题(论坛爬取)
  3. 将问题统计词频,使用算法分类,构造RAG-ICL样本库,编写校对SQL语句和回答
  4. 问题进入后分类

对于SQL查询问题,新问题使用Qwen大模型的tokenizer统计词频,与RAG-ICL样本库中的问题比较,选取最相似的2-4个问题-SQL语句对加入prompt,利用Qwen大模型做“填空与替换”,生成高可解释性与可靠性的SQL查询。运行查询,利用Qwen大模型将查询结果和问题生成为答案。

对于文本理解问题,新问题使用Qwen大模型的tokenizer统计词频,与分段的Text+表格文件计算总词频加权的余弦相似度,选出与问题最相关的20个文本片段,利用他们生成答案。

项目特点

项目特点

  1. 大模型基础
    • 本项目基于 Qwen大模型,这是一个强大的大语言模型。
  2. 数据集
    • 本项目使用包括10张基金表数据和80篇招股书文档,数据较为丰富,涵盖了基金的基本信息、股票持仓明细、债券持仓明细、可转债持仓明细、日行情等。
    • 数据时间跨度为2019年至2021年。
  3. 以检索增强上下文学习(RAG-ICL)为核心,使得我们可以在仅标注较少数据集 (不到200条),不精调模型的情况下快速得到一个效果较好的问答系统。

同时本方案具有较高的可解释性(只需修改RAG-ICL样本库中的例子就可以修改生成效果),对于可见的新问题类型也无需调整模型,只需补充RAG-ICL样本即可

前期数据准备

我们收集整理好一些金融相关问题的Data,大致如下

  • 基金基本信息
  • 基金股票持仓明细
  • 基金债券持仓明细
  • 基金可转债持仓明细
  • 基金日行情表
  • A股票日行情表
  • 港股票日行情表
  • A股公司行业划分表
  • 基金规模变动表
  • 基金份额持有人结构'

还有一部分招股说明书

其中部分数据是db,部分是表格,我们统一转化为db

如果碰到pdf格式,转为txt方便处理

在Python中,我们读取问题文件(问题收集于金融论坛,爬虫而得)

def read_jsonl(path):
    content = []
    with jsonlines.open(path, "r") as json_file:
        for obj in json_file.iter(type=dict, skip_invalid=True):
            content.append(obj)
    return content

将带有列名 “问题id” 和 “问题” 的标题行写入

def write_jsonl(path, content):
    with jsonlines.open(path, "w") as json_file:
        json_file.write_all(content)

还要把所有空格删除,最后写入CSV文件方便后续处理

知识收集

Tokenizer

Tokenizer是自然语言处理(NLP)中的一个关键组件,用于将文本分割成较小的单元(通常称为“tokens”)。这些单元可以是单词、子词、字符或其他语言学意义上的单位。Tokenizer的主要目的是将原始文本转换为模型可以处理的格式。

  1. Word Tokenizer(词级别分词器):将文本分割成单个单词。例如,"I love NLP"会被分割成["I", "love", "NLP"]。

  2. Subword Tokenizer(子词级别分词器):将文本分割成更小的单元,可以是词的一部分。这在处理新词、罕见词或多语言文本时非常有用。常见的子词分词器包括Byte-Pair Encoding (BPE)、WordPiece和SentencePiece。例如,"unhappiness"可能会被分割成["un", "happiness"]。

  3. Character Tokenizer(字符级别分词器):将文本分割成单个字符。例如,"NLP"会被分割成["N", "L", "P"]。

在LLM中,Tokenizer有着关键的作用,它是是文本预处理的第一步,还直接影响到模型的性能和处理能力。Tokenizer有以下作用

  1. 文本预处理:将原始文本分割成token,并将其转换为模型可以理解的输入格式(模型不能直接理解或处理原始的文本数据。相反,文本数据需要转换成一种模型可以处理的数值格式。)
  2. 处理未见词:大模型需要能够处理未在训练集中出现的词语
  3. 节省计算资源:通过使用子词tokenizer,可以减少词汇表的大小,从而减少模型的参数量和计算资源需求。

RAG

我们的系统的核心是RAG(Retrieval-Augmented Generation),这是一种结合检索与生成模型的方法,用于提升生成模型的准确性和效果。将外部知识库的检索机制与生成式语言模型结合起来,以生成更加准确和上下文相关的文本回答。

RAG的工作原理

RAG结合了检索(Retrieval)和生成(Generation)两种方法:

  1. 检索:首先从一个大型文档集合(例如知识库、数据库)中检索相关的文档片段。
  2. 生成:然后将检索到的文档片段与原始问题一起输入到生成模型,生成回答。生成模型在生成回答时,可以利用检索到的文档片段中的信息,提升回答的准确性和相关性。

我们将新问题和与新问题最相似的模板输入大模型,大模型输出新问题对应的SQL语句.,这样我们的模型只需要补足,而不是从头生成。

我们使用RAG模型的主要目的是为了生成SQL查询的自然语言描述

代码的主要任务是执行SQL查询,并将结果以自然语言的形式描述出来。这种任务需要将结构化的数据库查询结果转换为易于理解的自然语言描述。

优势

使用RAG有独有的优势:

  • 能够实时利用最新的信息,不受训练数据的限制。
  • 提高了模型在长尾数据和稀有知识点上的性能。

TF-IDF理论

TF-IDF是一种用于衡量词语在文档集中的重要性的方法。它由两个部分组成:词频(Term Frequency, TF)和逆文档频率(Inverse Document Frequency, IDF)。

词频衡量一个词在文档中出现的频率。对于一个给定的词 t 和文档 d,其词频可以表示为:

逆文档频率衡量一个词在整个文档集中的重要性。一个词在越多的文档中出现,其IDF值越低。对于词 t 和文档集 D,其逆文档频率可以表示为:

将词频和逆文档频率相乘,得到词 t 在文档 d 中的TF-IDF值:TF-IDF值越高,表示词 t 在文档 d 中越重要。

余弦相似度是一种衡量两个向量在向量空间中夹角余弦值的相似性度量方法。余弦相似度通常用于计算文本向量之间的相似性。

将文档表示为TF-IDF向量后,我们就可以使用余弦相似度来计算文档之间的相似性。

问题分类

我们使用预训练语言模型对问题进行分类,判断它们是否与公司的招股说明书(招股说明书)或基金股票数据库(基金股票数据库)相关。

加载模型和分词器

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

读入包含问题数据和公司数据的文件,加载预训练语言模型,根据回答将问题分类好写入csv文件

model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="cuda:0", trust_remote_code=True, bf16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(model_dir,
                                                           trust_remote_code=True,
                                                           temperature=0.0000001,
                                                           top_p=1,
                                                           do_sample=False,
                                                           seed=1234)
print('A01_model_loaded')

g = open('/app/intermediate/A01_question_classify.csv', 'w', newline='', encoding='utf-8-sig') 
csvwriter = csv.writer(g)
csvwriter.writerow(['问题id','问题','答案','分类'])

此处的prompt经过测试如下的效果较好

prompt = """
    你是一个问题分类器。对于每个提供给你的问题,你需要猜测答案是在该公司的招股说明书中还是在基金股票数据库里。以下是一些例子:

    问题:“在2019年的中期报告里,XX基金管理有限公司管理的基金中,有多少比例的基金是个人投资者持有的份额超过机构投资者?希望得到一个精确到两位小数的百分比。”
    回答:“基金股票数据库”
    
    问题:“XXXX股份有限公司变更设立时作为发起人的法人有哪些?”
    回答:“该公司的招股说明书”
    
    (省略其他示例问题)
    
    根据上面提供的例子对以下问题进行分类。
    问题:“
    """

循环处理每个问题,将问题加入提示模板,并调用模型生成回答。根据模型的回答判断问题类别(招股说明书或股票数据库),并检查问题中是否包含公司名称,以进行进一步分类。将分类结果写入CSV文件并定期输出进度。

for cyc in range(len(new_question_file)):
    temp_question = new_question_file.iloc[cyc]['问题']

    prompt1 = prompt + temp_question + """?"""

    response_new, history_new = model.chat(tokenizer, prompt1, history=None)
    if cyc % 100 == 0:
        print(str(new_question_file.iloc[cyc]['问题id']))

    if '招股说明书' in response_new and '股票数据库' not in response_new:
        temp_class = 'Text'
    elif '招股说明书' not in response_new and '股票数据库' in response_new:
        temp_class = 'SQL'
        for company_name in company_list:
            if company_name in temp_question:
                temp_class = 'Text'
    else:
        temp_class = 'SQL'
        for company_name in company_list:
            if company_name in temp_question:
                temp_class = 'Text'
    if cyc in [166,174]:
        temp_calss = 'Text'


    csvwriter.writerow([str(new_question_file.iloc[cyc]['问题id']),
                    str(new_question_file.iloc[cyc]['问题']),
                    response_new, temp_class])
g.close()

exit()

 问题与对应公司实体关联

对于分类为 "Text" 的问题:

    if new_question_file.iloc[cyc]['分类'] == 'Text':
        temp_index_q = tokenizer(new_question_file.iloc[cyc]['问题'])
        temp_index_q = temp_index_q['input_ids']
        q_cp_similarity_list = list()
        for cyc2 in range(len(company_file)):
            temp_index_cp = company_index_list[cyc2]
            temp_simi = len(set(temp_index_cp) & set(temp_index_q)) / (len(set(temp_index_cp)) + len(set(temp_index_q)))
            q_cp_similarity_list.append(temp_simi)
            
        t = copy.deepcopy(q_cp_similarity_list) 
        max_number = []
        max_index = []
        
        for _ in range(1):
            number = max(t)
            index = t.index(number)
            t[index] = 0
            max_number.append(number)
            max_index.append(index)
        t = []
        tempw_entity = company_name_list[max_index[0]]
        tempw_csv_name = company_data_csv_list[max_index[0]]

对问题进行分词,计算问题分词ID与公司名称分词ID的相似度,通过集合交集计算相似度比例,找到最相似的公司名称及其对应的CSV文件名。

分类为sql直接写入csv文件

其他情况则检查问题中是否包含公司名称,并根据包含的公司名称确定对应实体和CSV文件名。如果没有找到匹配的公司名称,则默认写入"N_A"。

sql问题

此部分内容将已经分类好的问题生成sql,用已经生成好的sql去数据库中查询

文本理解问题

对于文本理解问题,我们读取问题后在公司报告数据找到与问题最相关的文本片段,将最终结果输出到csv文件。

具体寻找过程我们利用余弦相似度

def counter_cosine_similarity(c1, c2, normalized_dict):
    terms = set(c1).union(c2)
    dotprod = sum(c1.get(k, 0) * c2.get(k, 0)/normalized_dict.get(k,1) for k in terms)
    magA = math.sqrt(sum(c1.get(k, 0)**2/(normalized_dict.get(k,1)**2) for k in terms))
    magB = math.sqrt(sum(c2.get(k, 0)**2/(normalized_dict.get(k,1)**2) for k in terms))
    
    if magA * magB != 0:
        return dotprod / (magA * magB)
    else:
        return 0

 找到相似度最高的前n个片段,并记录它们的索引和相似度值,将问题ID、问题、实体、相关文件名、相似度最高的片段索引和相似度值写入CSV文件。

生成答案

我们获得了问题后,将问题传入预训练模型,根据问题从文本片段中生成答案。

读取两个CSV文件,分别存储了问题及其相关的文本片段和其他辅助信息。

加载语言模型,并设置生成配置。

定义停用词列表

开始循环处理问题。

Prompt

Prompt是用来引导或提示模型生成期望输出的输入文本。Prompt告诉模型需要执行什么任务。例如,给定一个问题,模型需要生成答案;给定一个不完整的句子,模型需要补全。

Prompt也会提供上下文信息,使模型能够在生成响应时考虑相关的背景,设计不同的Prompt,可以控制模型的输出风格、内容和格式。

需要注意的是,Prompt将用户输入问题转化为模型可以理解的形式,引导模型生成特定响应。通过精心设计Prompt,可以显著提升大模型在各种任务中的表现。

BQA_Prompt

此代码模块用于在处理不同类型任务时生成合适的Prompt,将复杂任务拆分为多个步骤,利用模板和类的灵活性,使得整个过程更加模块化。

总体思路:

用户问题 -> 初始化系统Prompt -> 生成用户问题Prompt -> 更新当前任务Prompt -> 生成包含当前任务的Prompt -> 生成SQL查询 -> 执行SQL查询获取结果 -> 处理和融合结果 -> 返回最终结果

代码 

from .prompt import PromptGenerator

BQA_PLAN_DEFAULT_PROMPT = "你是一名高级智能助手,你可以先对问题进行分类,问题类型只有公司招股书咨询和股票基金数据查询两类,然后根据所给的信息列出回答该问题的任务列表。股票基金数据查询提供的表如下:A股公司行业划分表, A股票日行情表, 基金份额持有人结构, 基金债券持仓明细, 基金可转债持仓明细, 基金基本信息, 基金日行情表, 基金股票持仓明细, 基金规模变动表, 港股票日行情表。"
BQA_TASK_DEFAULT_PROMPT = "你是一名高级智能助手,你需要根据当前提供的信息,执行当前任务。"
BQA_CHAIN_PROMPT = "你是一名高级智能助手, 你需要针对用户问题,选择使用合适的插件。"
BQA_TASK_INSTRUCTION_TEMPLATE = """当前任务可以使用的插件信息如下,请自行判断是否需要调用插件来解决当前用户问题。若需要调用插件,则需要将插件调用请求按照json格式给出,必须包含api_name、parameters字段,并在其前后使用<|startofthink|>和<|endofthink|>作为标志。\
若无需调用插件,直接执行任务,结果无需标志。
{tool_list}"""

SCHEME_STRUCTURE_DICT = {
    'A股公司行业划分表':
        '''
        字段 类型
        股票代码 TEXT 
        交易日期 TEXT
        行业划分标准 TEXT
        一级行业名称 TEXT
        二级行业名称 TEXT
        ''',
    'A股票日行情表':
        '''
        字段 类型
        股票代码 TEXT
        交易日 TEXT
        [昨收盘(元)] REAL
        [今开盘(元)] REAL
        [最高价(元)] REAL
        [最低价(元)] REAL
        [收盘价(元)] REAL
        [成交量(股)] REAL
        [成交金额(元)] REAL
        ''',
    '基金份额持有人结构':
        '''
        字段 类型
        基金代码 TEXT
        基金简称 TEXT
        公告日期 TIMESTAMP
        截止日期 TIMESTAMP
        机构投资者持有的基金份额 REAL
        机构投资者持有的基金份额占总份额比例 REAL
        个人投资者持有的基金份额 REAL
        个人投资者持有的基金份额占总份额比例 REAL
        定期报告所属年度 INTEGER
        报告类型 TEXT
        ''',
    '基金债券持仓明细':
        '''
        字段 类型
        基金代码 TEXT
        基金简称 TEXT
        持仓日期 TEXT
        债券类型 TEXT
        债券名称 TEXT
        持债数量 REAL
        持债市值 REAL
        持债市值占基金资产净值比 REAL
        第N大重仓股 INTEGER
        所在证券市场 TEXT
        [所属国家(地区)] TEXT
        报告类型TEXT TEXT
        ''',
    '基金可转债持仓明细':
        '''
        字段 类型
        基金代码 TEXT
        基金简称 TEXT
        持仓日期 TEXT
        对应股票代码 TEXT
        债券名称 TEXT
        数量 REAL
        市值 REAL
        市值占基金资产净值比 REAL
        第N大重仓股 INTEGER
        所在证券市场 TEXT
        [所属国家(地区)] TEXT
        报告类型 TEXT
        ''',
    '基金基本信息':
        '''
        字段 类型
        基金代码 TEXT
        基金全称 TEXT
        基金简称 TEXT
        管理人 TEXT
        托管人 TEXT
        基金类型 TEXT
        成立日期 TEXT
        到期日期 TEXT
        管理费率 TEXT
        托管费率 TEXT
        ''',
    '基金日行情表':
        '''
        字段 类型
        基金代码 TEXT
        交易日期 TEXT
        单位净值 REAL
        复权单位净值 REAL
        累计单位净值 REAL
        资产净值 REAL
        ''',
    '基金股票持仓明细':
        '''
        字段 类型
        基金代码 TEXT
        基金简称 TEXT
        持仓日期 TEXT
        股票代码 TEXT
        股票名称 TEXT
        数量 REAL
        市值 REAL
        市值占基金资产净值比 REAL
        第N大重仓股 INTEGER
        所在证券市场 TEXT
        [所属国家(地区)] TEXT
        报告类型 TEXT
        ''',
    '基金规模变动表':
        '''
        字段 类型
        基金代码 TEXT
        基金简称 TEXT
        公告日期 TIMESTAMP
        截止日期 TIMESTAMP
        报告期期初基金总份额 REAL
        报告期基金总申购份额 REAL
        报告期基金总赎回份额 REAL
        报告期期末基金总份额 REAL
        定期报告所属年度 INTEGER
        报告类型 TEXT
        ''',
    '港股票日行情表':
        '''
        字段 类型
        股票代码 TEXT
        交易日 TEXT
        [昨收盘(元)] REAL
        [今开盘(元)] REAL
        [最高价(元)] REAL
        [最低价(元)] REAL
        [收盘价(元)] REAL
        [成交量(股)] REAL
        [成交金额(元)] REAL
        '''
}

BQA_USER_QUESTION_TEMPLATE = "用户问题:{user_question}"
BQA_CURRENT_TASK_TEMPLATE = "当前任务:{current_task}"

BQA_CHAT_KNOWLEDGE_TEMPLATE = """------检索内容开始------
{extra_knowledge}
------检索内容结束------

用户问题:{user_question}。
完全根据检索内容结合问题回答用户问题,将问题和答案结合后输出。注意不要输出“根据检索”。
"""

# BQA_CHAT_KNOWLEDGE_TEMPLATE="""------检索内容开始------
# {extra_knowledge}
# ------检索内容结束------

# 用户问题:{user_question}。
# 完全根据检索内容结合问题回答用户问题,将问题和答案结合后输出;
# 若在检索内容中无答案,输出“问题” + “并未在招股意向书中详细说明”,如用户问题:“上海华铭智能终端设备股份有限公司的首发战略配售结果如何?”,输出:“上海华铭智能终端设备股份有限公司的首发战略配售具体情况并未在招股意向书中详细说明。”。"""

BQA_SQL_GENERATOR_TEMPLATE = """你是一名高级数据库工程师,请你根据所提供的表结构说明以及用户问题,生成sql语句,数据库为sqlite,你生成的sql语句格式必须符合sqlite格式。
------表结构说明开始------
{table_structure_introduction}
------表结构说明结束------

用户问题:{user_question}。
注意:答案只需要sql语句,不需要其他任何输出。
"""

BQA_SQL_GENERATOR_TEMPLATE_1 = "你是一名sqlite数据库开发人员,精通sql语言,你需要根据已知的10张表的表名、字段名和用户输入的问题编写sql\n\n" \
                               "{'表名': '基金基本信息', '字段名': ['基金代码', '基金全称', '基金简称', '管理人', '托管人', '基金类型', '成立日期', '到期日期', '管理费率', '托管费率']}\n" \
                               "{'表名': '基金股票持仓明细', '字段名': ['基金代码', '基金简称', '持仓日期', '股票代码', '股票名称', '数量', '市值', '市值占基金资产净值比', '第N大重仓股', '所在证券市场', '[所属国家(地区)]', '报告类型']}\n" \
                               "{'表名': '基金债券持仓明细', '字段名': ['基金代码', '基金简称', '持仓日期', '债券类型', '债券名称', '持债数量', '持债市值', '持债市值占基金资产净值比', '第N大重仓股', '所在证券市场', '[所属国家(地区)]', '报告类型']}\n" \
                               "{'表名': '基金可转债持仓明细', '字段名': ['基金代码', '基金简称', '持仓日期', '对应股票代码', '债券名称', '数量', '市值', '市值占基金资产净值比', '第N大重仓股', '所在证券市场', '[所属国家(地区)]', '报告类型']}\n" \
                               "{'表名': '基金日行情表', '字段名': ['基金代码', '交易日期', '单位净值', '复权单位净值', '累计单位净值', '资产净值']}\n" \
                               "{'表名': 'A股票日行情表', '字段名': ['股票代码', '交易日', '[昨收盘(元)]', '[今开盘(元)]', '[最高价(元)]', '[最低价(元)]', '[收盘价(元)]', '[成交量(股)]', '[成交金额(元)]']}\n" \
                               "{'表名': '港股票日行情表', '字段名': ['股票代码', '交易日', '[昨收盘(元)]', '[今开盘(元)]', '[最高价(元)]', '[最低价(元)]', '[收盘价(元)]', '[成交量(股)]', '[成交金额(元)]']}\n" \
                               "{'表名': 'A股公司行业划分表', '字段名': ['股票代码', '交易日期', '行业划分标准', '一级行业名称', '二级行业名称']}\n" \
                               "{'表名': '基金规模变动表', '字段名': ['基金代码', '基金简称', '公告日期', '截止日期', '报告期期初基金总份额', '报告期基金总申购份额', '报告期基金总赎回份额', '报告期期末基金总份额', '定期报告所属年度', '报告类型']}\n" \
                               "{'表名': '基金份额持有人结构', '字段名': ['基金代码', '基金简称', '公告日期', '截止日期', '机构投资者持有的基金份额', '机构投资者持有的基金份额占总份额比例', '个人投资者持有的基金份额', '个人投资者持有的基金份额占总份额比例', '定期报告所属年度', '报告类型']}\n\n" \
                               "请根据以下用户输入编写sql。\n用户输入: {user_question}"

BQA_CHAT_SQLRESULT_TEMPLATE = """问题:“{user_question}”。
答案:“{sql_result}”。

将问题的内容和答案的内容融合的文字内容输出。注意不要输出“问题:”或“答案:”。
"""


class BQAPromptGenerator(PromptGenerator):
    def __init__(self,
                 plan_template=BQA_PLAN_DEFAULT_PROMPT,
                 task_template=BQA_TASK_DEFAULT_PROMPT,
                 task_instruction_template=BQA_TASK_INSTRUCTION_TEMPLATE,
                 user_template=BQA_USER_QUESTION_TEMPLATE,
                 current_task_template=BS_CURRENT_TASK_TEMPLATE,
                 sep='\n\n',
                 prompt_max_length=10000):
        super().__init__(plan_template, task_template, task_instruction_template, user_template, current_task_template,
                         sep,
                         prompt_max_length)

    def generate(self, task_no=None):
        # init plan
        if task_no is None:
            prompt_list = [self.system_prompt,
                           self.user_prompt]
        # execute tasks
        else:
            # no task result
            if not self.task_result_prompt:
                prompt_list = [self.system_prompt,
                               self.user_prompt,
                               self.current_task_prompt]
            else:
                prompt_list = [self.system_prompt,
                               self.task_result_prompt,
                               self.user_prompt,
                               self.current_task_prompt]
        return self.sep.join(prompt_list)

    def update_task_prompt(self, current_task):
        self.current_task_prompt = self.current_task_template.replace("{current_task}", current_task)


class BQAChainPromptGenerator(PromptGenerator):
    def __init__(self,
                 chain_template=BQA_CHAIN_PROMPT,
                 task_instruction_template=BQA_TASK_INSTRUCTION_TEMPLATE,
                 user_template=BQA_USER_QUESTION_TEMPLATE,
                 sep='\n\n'):
        self.chain_template = chain_template
        self.task_instruction_template = task_instruction_template
        self.user_template = user_template
        self.sep = sep

    def init_prompt(self, tool_list):
        self.system_prompt = self.chain_template + self.task_instruction_template.replace("{tool_list}",
                                                                                          self.get_tool_str(tool_list))

    def generate(self, user_question):
        self.user_prompt = self.user_template.replace("{user_question}", user_question)
        return self.sep.join([self.system_prompt, self.user_prompt])

PromptGenerator类

上面的Prompt编写中的Prompt生成器的类定义,用于生成复杂的提示字符串,通过模板和用户输入生成适合不同场景的提示,并维护历史记录以确保提示长度不超过指定的最大长度。

完整代码:

import re


class PromptGenerator:

    def __init__(self,
                 plan_template: str = '',
                 task_template: str = '',
                 task_instruction_template: str = '',
                 user_template: str = '',
                 current_task_template: str = '',
                 sep='\n\n',
                 prompt_max_length: int = 10000):
        """
        prompt genertor
        Args:
            system_template (str, optional): System template, normally the role of LLM.
            instruction_template (str, optional): Indicate the instruction for LLM.
            user_template (str, optional): Prefix before user input. Defaults to ''.
            exec_template (str, optional): A wrapper str for exec result.
            assistant_template (str, optional): Prefix before assistant response.
            Some LLM need to manully concat this prefix before generation.
            prompt_max_length (int, optional): max length of prompt. Defaults to 2799.

        """

        self.plan_template = plan_template
        self.task_template = task_template
        self.task_instruction_template = task_instruction_template
        self.user_template = user_template
        self.current_task_template = current_task_template
        self.sep = sep

        self.prompt_max_length = prompt_max_length
        self.reset()

    def reset(self):
        self.prompt = ''

    def init_plan_prompt(self, user_question):
        """
        in this function, the prompt will be initialized.
        """
        self.system_prompt = self.plan_template
        self.user_prompt = self.user_template.replace("{user_question}",user_question)
        self.current_task_prompt = None
        self.task_result_prompt = None


    def init_task_prompt(self,user_question, tool_list):
        self.system_prompt = self.task_template + self.task_instruction_template.replace("{tool_list}",self.get_tool_str(tool_list))
        self.user_prompt = self.user_template.replace("{user_question}",user_question)
        self.current_task_prompt = self.current_task_template
        self.task_result_prompt = None

    def generate(self):
        """
        generate next round prompt based on previous llm_result and exec_result and update history
        """
        pass

    def get_tool_str(self, tool_list):
        """generate tool list string

        Args:
            tool_list (List[str]): list of tools

        """
        tool_str = self.sep.join(
            [f'{i+1}. {t}' for i, t in enumerate(tool_list)])
        return tool_str

    def get_history_str(self):
        """generate history string

        """
        history_str = ''
        for i in range(len(self.history)):
            history_item = self.history[len(self.history) - i - 1]
            text = history_item['content']
            if len(history_str) + len(text) + len(
                    self.prompt) > self.prompt_max_length:
                break
            history_str = f'{self.sep}{text.strip()}{history_str}'

        return history_str

谱聚类

介绍

谱聚类(Spectral Clustering)是一种基于图论的聚类方法,广泛应用于机器学习和数据挖掘领域。它利用数据点之间的相似性信息,通过谱图理论将数据嵌入到低维空间中,然后在这个低维空间中进行聚类。

基本步骤

  1. 构建相似性矩阵

    • 首先,计算数据集中每对点之间的相似性,生成一个相似性矩阵(Affinity Matrix)。相似性通常通过高斯核函数或k近邻方法来计算。
  2. 构建图拉普拉斯矩阵

    • 从相似性矩阵构建图拉普拉斯矩阵(Graph Laplacian)。图拉普拉斯矩阵有多种形式,常见的有未归一化的拉普拉斯矩阵和归一化的拉普拉斯矩阵。
  3. 计算拉普拉斯矩阵的特征值和特征向量

    • 对拉普拉斯矩阵进行特征值分解,得到前k个最小的特征值对应的特征向量。这些特征向量形成一个新的矩阵。
  4. 嵌入低维空间

    • 将数据点嵌入到由这些特征向量定义的低维空间中。
  5. 应用传统聚类算法

    • 在低维空间中,对嵌入的点应用传统的聚类算法(如k-means),得到最终的聚类结果。

例子理解

假设有一群人,把他们分成几个小组,使得同一组内的人彼此之间关系很近,而不同组的人关系较远。在一个派对上,我们让互相熟悉的人待在一起,而不熟的人分开。

我们的步骤类比于:

  • 观察所有人之间的关系
  • 用这些关系来建立一个图
  • 利用这个图找到隐藏的关系模式
  • 根据这些模式在低维空间中重新排列人们的位置
  • 最后根据新排列的位置,把人们分成几组,使得组内的人关系最近

我们可以使用谱聚类在复杂的关系网络中找到合适分组,让关系更加亲近的数据在一起。

总之谱聚类通过将数据点间的相似性信息转化为图结构,并在低维空间中进行聚类,实现了对复杂数据结构的有效分割和识别。

项目使用

我们要采用RAG-ICL的方法,需要构建样本库。通过谱聚类方法,可以对样本库中的数据进行更有效的组织和分类,从而提升检索效率和生成效果。

将1000个Questions中不含有公司名称的问题(即潜在的使用SQL查询进行回答 的问题),使用tokenizer统计词频,两两计算余弦相似度后进行谱聚类,类别个数75,可以看到每 类问题基本具有相同格式。从每类问题中随机选择2个,构成RAG-ICL的样本库,编写并校对 SQL语句和回答。

我们需要通过计算谱聚类来确定合适的类数量,使得样本库覆盖所有种类的问题。首先,我们可以使用一种方法来自动确定最佳聚类数。这里,我们可以使用“肘部法”或“轮廓系数”来找到最佳的聚类数。

1. 数据向量化

首先将文本数据转换为向量表示:

token_counts_list = questions_df['问题'].apply(tokenize_and_count).tolist()
all_tokens = set(token for token_counts in token_counts_list for token in token_counts)
vocabulary = list(all_tokens)
python
  • token_counts_list:对每个问题进行分词并计算每个词的出现频率,生成一个词频统计列表。
  • all_tokens:收集所有问题中的所有词,创建一个包含所有词的集合。
  • vocabulary:将所有词的集合转换为列表,作为词汇表。

2. 将词频转换为向量

将每个问题的词频统计转换为向量表示:

vectors = np.array([vectorize(token_counts, vocabulary) for token_counts in token_counts_list])
python
  • vectors:将每个问题的词频统计转换为向量表示,这样每个问题都对应一个高维向量。

3. 计算相似度矩阵

计算所有问题之间的相似度矩阵:

similarity_matrix = cosine_similarity(vectors)
python
  • similarity_matrix:使用余弦相似度计算所有问题向量之间的相似度,生成相似度矩阵。

4. 寻找最佳聚类数

定义一个函数,通过计算轮廓系数来寻找最佳聚类数:

def find_best_cluster_number(similarity_matrix, max_clusters=10):
    silhouette_scores = []
    for n_clusters in range(2, max_clusters + 1):
        clustering = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', assign_labels='kmeans')
        cluster_labels = clustering.fit_predict(similarity_matrix)
        silhouette_avg = silhouette_score(similarity_matrix, cluster_labels, metric="precomputed")
        silhouette_scores.append((n_clusters, silhouette_avg))

    best_n_clusters = max(silhouette_scores, key=lambda x: x[1])[0]
    return best_n_clusters
python
  • find_best_cluster_number:定义一个函数,使用谱聚类算法对不同的聚类数进行聚类,并计算每个聚类数下的轮廓系数(silhouette score)。
  • silhouette_scores:保存每个聚类数对应的轮廓系数。
  • best_n_clusters:找到使轮廓系数最大的聚类数,即最佳聚类数。

5. 找到最佳聚类数

调用函数,找到最佳聚类数:

best_num_clusters = find_best_cluster_number(similarity_matrix, max_clusters=20)
print(f"最佳聚类数: {best_num_clusters}")
python
  • best_num_clusters:找到最佳聚类数并打印。

6. 进行谱聚类

使用最佳聚类数进行谱聚类:

clustering = SpectralClustering(n_clusters=best_num_clusters, affinity='precomputed', assign_labels='kmeans')
labels = clustering.fit_predict(similarity_matrix)
questions_df['Cluster'] = labels
python
  • clustering:使用最佳聚类数初始化谱聚类模型,指定相似度矩阵作为预计算的相似度(affinity='precomputed'),使用 k-means 进行标签分配。
  • labels:对相似度矩阵进行聚类,得到每个问题的聚类标签。
  • questions_df['Cluster']:将聚类标签添加到原始数据框中。

通过以上步骤,谱聚类算法将相似的问题分配到同一个聚类中,从而实现对问题的分类。这样可以在后续处理中对每个聚类中的问题进行进一步分析或处理。

完整代码:

from flask import Flask, request, jsonify

app = Flask(__name__)


@app.route('/')
def hello_world():  # put application's code here
    return 'Hello World!'

@app.route('/ask', methods=['POST'])
def ask():
    data = request.get_json()
    question = data.get('question')
    # 在此处实现你的对话逻辑,例如返回一个简单的响应
    response = {"answer": "This is a placeholder answer."}
    return jsonify(response)


if __name__ == '__main__':
    app.run()


import pandas as pd
import numpy as np
from modelscope import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import SpectralClustering
from sklearn.metrics import silhouette_score
import random

csv_file = "questions.csv"
questions_df = pd.read_csv(csv_file)

model_dir = "path_to_your_model_directory"
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

def tokenize_and_count(text):
    tokens = tokenizer.tokenize(text)
    return dict((token, tokens.count(token)) for token in set(tokens))

def vectorize(token_counts, vocabulary):
    vector = [token_counts.get(token, 0) for token in vocabulary]
    return vector

token_counts_list = sql_questions_df['问题'].apply(tokenize_and_count).tolist()
all_tokens = set(token for token_counts in token_counts_list for token in token_counts)
vocabulary = list(all_tokens)
```python
将词频转换为向量
```python
vectors = np.array([vectorize(token_counts, vocabulary) for token_counts in token_counts_list])

similarity_matrix = cosine_similarity(vectors)


def find_best_cluster_number(similarity_matrix, max_clusters=10):
    silhouette_scores = []
    for n_clusters in range(2, max_clusters + 1):
        clustering = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', assign_labels='kmeans')
        cluster_labels = clustering.fit_predict(similarity_matrix)
        silhouette_avg = silhouette_score(similarity_matrix, cluster_labels, metric="precomputed")
        silhouette_scores.append((n_clusters, silhouette_avg))

    best_n_clusters = max(silhouette_scores, key=lambda x: x[1])[0]
    return best_n_clusters

best_num_clusters = find_best_cluster_number(similarity_matrix, max_clusters=20)
print(f"最佳聚类数: {best_num_clusters}")
clustering = SpectralClustering(n_clusters=best_num_clusters, affinity='precomputed', assign_labels='kmeans')
labels = clustering.fit_predict(similarity_matrix)
sql_questions_df['Cluster'] = labels
sampled_questions = []
for cluster in range(best_num_clusters):
    cluster_questions = sql_questions_df[sql_questions_df['Cluster'] == cluster]
    if len(cluster_questions) > 2:
        sampled_questions.extend(cluster_questions.sample(n=2).to_dict('records'))
    else:
        sampled_questions.extend(cluster_questions.to_dict('records'))
for sample in sampled_questions:
    print(f"问题id: {sample['问题id']}, 问题: {sample['问题']}")
sampled_df = pd.DataFrame(sampled_questions)
sampled_df.to_csv("sampled_questions.csv", index=False)

对话模式

我们的模型目前是对所有问题进行读入,在经过模型运行后,将结果存入csv文件。我们要改造成可用的应用形式。

我们的设计页面可以参考GPT,毕竟GPT是大模型浪潮的始源,算是致敬。

设计思路

问题分类

对于每个输入的问题,我们将其分类,设置好Prompt后,将问题输入到模型。

如果模型回答包含招股说明书而没有股票数据库,说明是文本理解类,反之则是SQL类。

# 从用户输入获取问题
user_question = input("请输入问题:")

# 构建用于分类的提示
prompt = prompt_template.format(user_question)

# 调用模型进行分类
response_new, history_new = model.chat(tokenizer, prompt, history=None)

# 进行分类判断
if '招股说明书' in response_new and '股票数据库' not in response_new:
    question_class = 'Text'
elif '招股说明书' not in response_new and '股票数据库' in response_new:
    question_class = 'SQL'
    for company_name in company_list:
        if company_name in user_question:
            question_class = 'Text'
else:
    question_class = 'SQL'
    for company_name in company_list:
        if company_name in user_question:
            question_class = 'Text'

# 打印分类结果
print(f"问题: {user_question}")
print(f"分类: {question_class}")
print(f"模型回答: {response_new}")

生成SQL

利用已有的示例问题和SQL查询,通过相似度匹配找到与用户输入问题最相似的示例,然后基于这些示例生成SQL查询。这种方法利用了示例的多样性和丰富性,提高了生成SQL查询的准确性和可靠性。

具体流程:

  1. 数据库和模型设置

    • 连接到一个SQLite数据库,获取数据库表的结构信息。
    • 加载用于自然语言处理和SQL生成的模型和标记器。
  2. 禁用词列表

    • 定义一组禁用词,这些词在处理用户输入时会被过滤掉,以提高模型生成SQL的准确性。
  3. 加载SQL示例

    • 读取包含示例问题和对应SQL查询的CSV文件,并将这些示例存储在列表中,以便后续用于生成新的SQL查询。
    • 将每个示例问题转换为标记(tokens),并过滤掉禁用词。
  4. 用户输入处理

    • 从用户那里获取输入的自然语言问题。
    • 从问题中提取日期并替换为空格,然后将问题转换为标记,过滤掉禁用词。
    • 计算用户输入的问题与示例问题之间的相似度。
  5. 生成提示和SQL查询

    • 根据相似度选择最相似的若干个示例问题。
    • 生成一个包含这些示例问题和对应SQL查询的提示。
    • 使用生成的提示调用模型,生成与用户输入问题对应的SQL查询。
  6. 输出结果

    • 输出用户输入的问题、生成的SQL查询以及用于生成SQL查询的提示。
import re
import copy
from langchain.utilities import SQLDatabase
from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# 数据库和模型设置
db0 = SQLDatabase.from_uri("sqlite:tcdata/bs_challenge_financial_14b_dataset/dataset/博金杯比赛数据.db", sample_rows_in_table_info=0)
dbd0 = db0.table_info
db2 = SQLDatabase.from_uri("sqlite:tcdata/bs_challenge_financial_14b_dataset/dataset/博金杯比赛数据.db", sample_rows_in_table_info=2)
dbd2 = db2.table_info

# 获取表信息
table_name_list = ['基金基本信息','基金股票持仓明细','基金债券持仓明细','基金可转债持仓明细','基金日行情表','A股票日行情表','港股票日行情表','A股公司行业划分表','基金规模变动表','基金份额持有人结构']
table_info_dict = {}
list1 = dbd2.split('CREATE TABLE')
for cyc_piece in range(len(list1)):
    list1[cyc_piece] = 'CREATE TABLE' + list1[cyc_piece]
for piece in list1:
    for word in table_name_list:
        if word in piece:
            table_info_dict[word] = piece

# 禁用词列表
deny_list = ['0','1','2','3','4','5','6','7','8','9',',','?','。',
             '一','二','三','四','五','六','七','八','九','零','十',
             '的','小','请','.','?','有多少','帮我','我想','知道',
             '是多少','保留','是什么','-','(',')','(',')',':',
             '哪个','统计','且','和','来','请问','记得','有','它们']
deny_token_list = []

# 模型目录
model_dir = '/tcdata/models/Tongyi-Finance-14B-Chat'
# 加载模型和标记器
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="cuda:0", trust_remote_code=True, bf16=True).eval()
model.generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=True, temperature=0.0001, top_p=1, do_sample=False, seed=1234)
print('B01_model_loaded')

# 获取禁用词的tokens
for word in deny_list:
    temp_tokens = tokenizer(word)
    temp_tokens = temp_tokens['input_ids']
    deny_token_list = deny_token_list + temp_tokens

# SQL示例文件
SQL_examples_file_dir = "/app/data/files/ICL_EXP.csv"
SQL_examples_file = pd.read_csv(SQL_examples_file_dir, delimiter=",", header=0)

# 初始化示例问题和SQL
example_question_list = SQL_examples_file['问题'].tolist()
example_sql_list = SQL_examples_file['SQL'].tolist()
example_token_list = []
for question in example_question_list:
    temp_tokens = tokenizer(question)
    temp_tokens = temp_tokens['input_ids']
    temp_tokens2 = [x for x in temp_tokens if x not in deny_token_list]
    example_token_list.append(temp_tokens2)

def get_prompt_v33(question, index_list):
    Examples = '以下是一些例子:'
    for index in index_list:
        Examples += "问题:" + example_question_list[index] + '\n'
        Examples += "SQL:" + example_sql_list[index] + '\n'
    
    impt2 = """
        你是一个精通SQL语句的程序员。
        我会给你一个问题,请按照问题描述,仿照以下例子写出正确的SQL代码。
    """
    impt2 += Examples
    impt2 += "问题:" + question + '\n'
    impt2 += "SQL:"
    return impt2

# 获取用户输入的问题
user_question = input("请输入问题:")

# 初始化变量
response2 = 'N_A'
prompt2 = 'N_A'
pattern1 = r'\d{8}'
n = 5

# 处理问题
if user_question:
    date_list = re.findall(pattern1, user_question)
    temp_question2_for_search = user_question
    for t_date in date_list:
        temp_question2_for_search = temp_question2_for_search.replace(t_date, ' ')
    temp_tokens = tokenizer(temp_question2_for_search)
    temp_tokens = temp_tokens['input_ids']
    temp_tokens2 = [x for x in temp_tokens if x not in deny_token_list]
    temp_tokens = temp_tokens2
    
    # 计算与已有问题的相似度
    similarity_list = []
    for example_tokens in example_token_list:
        similarity = len(set(temp_tokens) & set(example_tokens)) / (len(set(temp_tokens)) + len(set(example_tokens)))
        similarity_list.append(similarity)
    
    # 求m个最大的数值及其索引
    t = copy.deepcopy(similarity_list)
    max_number = []
    max_index = []
    for _ in range(n):
        number = max(t)
        index = t.index(number)
        t[index] = 0
        max_number.append(number)
        max_index.append(index)
    t = []
    
    temp_length_test = ""
    short_index_list = []
    for index in max_index:
        temp_length_test_1 = temp_length_test
        temp_length_test += example_question_list[index]
        temp_length_test += example_sql_list[index]
        if len(temp_length_test) > 2300:
            break
        short_index_list.append(index)
    
    prompt2 = get_prompt_v33(user_question, short_index_list)
    response2, history = model.chat(tokenizer, prompt2, history=None)

# 输出结果
print(f"问题: {user_question}")
print(f"SQL语句: {response2}")
print(f"Prompt: {prompt2}")

根据SQL查询问题

我们总体思路是处理用户输入的SQL问题(由上面所得),通过计算问题与示例问题之间的相似度来生成合适的SQL查询结果和相应的回答。计算相似度在SQL查询后面进行,相似度计算可以基于具体的SQL查询结果进行,从而为生成的答案提供更好的上下文和参考。这样,模型可以利用最相关的示例问题生成更准确和相关的答案。

定义否定词列表
deny_list = ['0','1','2','3','4','5','6','7','8','9',',','?','。', '一','二','三','四','五','六','七','八','九','零','十', '的','小','请','.','?','有多少','帮我','我想','知道', '是多少','保留','是什么','-','(',')','(',')',':', '哪个','统计','且','和','来','请问','记得','有','它们']
pattern1 = r'\d{8}'

主要目的是过滤掉常见的、无意义的或者对相似度计算无帮助的词和符号,以提高问题之间相似度计算的准确性。这些否定词包括数字、标点符号以及一些常用的汉字和短语。

否定词列表很有必要,它有以下作用:

否定词列表的作用如下:

  1. 减少噪音:否定词通常是无意义的噪音词,对理解问题的核心内容没有帮助。去除这些词可以减少噪音,提高关键内容的突出性。

  2. 提高相似度计算的准确性:在计算问题之间的相似度时,常见词和符号可能会增加不相关内容的相似度。去除这些否定词可以确保相似度计算集中在有意义的词汇上,从而更准确地衡量两个问题的相似性。

  3. 优化模型输入:模型处理输入时,去除否定词可以减少不必要的计算负担,使模型更高效地处理有价值的信息。

通过这种方式,最终用于相似度计算的token列表只包含有意义的词汇,从而提高相似度计算的准确性和模型生成回答的质量。

定义提示生成函数
def get_prompt_v33(question, data, index_list):
    Examples = '以下是一些例子:'
    for index in index_list:
        Examples += f"问题:{example_question_list[index]}\n资料:{example_data_list[index]}\n答案:{example_FA_list[index]}\n"
    prompt = """
    你要进行句子生成工作,根据提供的资料来回答对应的问题。下面是一些例子。注意问题中对小数位数的要求。\n
    """ + Examples + f"问题:{question}\n资料:{data}\n答案:"
    return prompt

该函数用于生成模型的输入提示,包括示例问题、示例资料和示例答案。

读取处理示例问题数据
SQL_examples_file_dir = "/app/data/files/ICL_EXP.csv"
SQL_examples_file = pd.read_csv(SQL_examples_file_dir, delimiter=",", header=0)

example_question_list = SQL_examples_file['问题'].tolist()
example_data_list = SQL_examples_file['资料'].tolist()
example_FA_list = SQL_examples_file['FA'].tolist()
example_token_list = []

for question in example_question_list:
    temp_tokens = tokenizer(question)['input_ids']
    filtered_tokens = [x for x in temp_tokens if x not in deny_token_list]
    example_token_list.append(filtered_tokens)

从CSV文件读取示例问题、资料和答案,并将示例问题转化为token,存储在 example_token_list 中。

处理问题函数

def process_question(question, classification, SQL_search_result):
    if classification != 'SQL':
        return 'N_A', 'N_A'
    
    if not SQL_search_result or SQL_search_result == 'N_A':
        return 'SQL未能成功执行!', 'SQL未能成功执行!'

    if len(SQL_search_result) > 250:
        SQL_search_result = SQL_search_result[:250]
    
    date_list = re.findall(pattern1, question)
    question_for_search = question
    for date in date_list:
        question_for_search = question_for_search.replace(date, ' ')
    
    temp_tokens = tokenizer(question_for_search)['input_ids']
    filtered_tokens = [x for x in temp_tokens if x not in deny_token_list]

    similarity_list = [len(set(filtered_tokens) & set(example_tokens)) / (len(set(filtered_tokens)) + len(set(example_tokens))) for example_tokens in example_token_list]
    
    max_index = sorted(range(len(similarity_list)), key=lambda i: similarity_list[i], reverse=True)[:n]
    prompt = get_prompt_v33(question, SQL_search_result, max_index)
    response, _ = model.chat(tokenizer, prompt, history=None)
    
    return response, SQL_search_result

这个函数处理用户输入的问题,根据问题的分类和SQL结果,计算与示例问题的相似度,生成提示并调用模型生成答案。

处理输入问题
def handle_user_input(user_question, user_classification, user_SQL_result):
    temp_FA, SQL_search_result = process_question(user_question, user_classification, user_SQL_result)
    return {'question': user_question, 'FA': temp_FA, 'SQL_result': SQL_search_result}

调用 process_question 函数,返回问题、答案和SQL结果。

SSH通信

我们的容器跑在AutoDL平台,个人用户没有公网端口。想要访问到容器中的数据我们需要SSH隧道技术。

SSH隧道技术(SSH Tunneling)是一种通过安全外壳协议(SSH)将网络数据封装并通过安全通道传输的技术。它能够在不安全的网络环境中保护数据传输的安全性,常用于远程访问内部网络资源、加密数据流量和绕过网络限制。

隧道是在两个网络节点之间创建的一条虚拟通道,通过这条通道传输的数据会被加密,从而防止数据被截获或篡改。

我们通过SSH的登录指令和密码

通过AutoDL的ddl文件实现SSH隧道技术通信,成功实现访问容器数据,访问容器部署服务。

utils

post_process_sql_result.py

此段代码主要是对sql的result结果进行后处理,包括添加百分比符号、取整和调整小数位数。

  • 添加百分比符号:如果问题中提到费率、百分比或涨跌幅,而答案中没有百分比符号,则在答案中添加百分比符号。
  • 取整:如果问题中提到取整并且答案包含.0,则移除.0
  • 调整小数位数:根据问题中的要求,将答案中的小数位数调整到指定的位数。

整体流程

  • 读取输入文件

    • input = 'groundtruth_1124.json':输入文件名。
    • 使用json.loads读取JSON文件内容。
  • 处理每个问题和答案

    • 对每个数据项d中的问题和答案进行处理:
      • 移除答案中的引号、方括号和括号。
      • 检查问题是否包含费率、百分比或涨跌幅,必要时在答案中添加百分比符号。
      • 检查问题是否要求取整,并移除答案中的.0
      • 分割问题字符串,检查是否要求调整小数位数,并调用refine_xiaoshu函数调整答案。

实例

[
    {"q": "请给出这次测量的费率", "a": "0.1234"},
    {"q": "这个值保留两位小数", "a": "0.1234"},
    {"q": "涨跌幅是多少", "a": "0.5"}
]

输出可能为:

[
    {"q": "请给出这次测量的费率", "a": "0.1234%"},
    {"q": "这个值保留两位小数", "a": "0.12"},
    {"q": "涨跌幅是多少", "a": "0.5%"}
]

代码


import re
import json
p = "\d+\.\d+"
bd = '[=,.?!@#$%^&*()_+:"<>/\[\]\\`~——,。、《》?;’:“【】、{}|·!¥…()-]'
weishu_dict = {
    "1":1,
    "2":2,
    "3":3,
    "4":4,
    "5":5,
    "一":1,
    "二":2,
    "三":3,
    "四":4,
    "五":5,
    "两":2
}

def refine_xiaoshu(answer, ji):
    a = answer.strip()
    a_li = a.split(".")
    xiaoshu_part = re.match("\d+", a_li[-1]).group()
    new_xiaoshu_part = ""
    cur = 0
    while cur < weishu_dict[ji]:
        if len(xiaoshu_part) > cur:
            new_xiaoshu_part += xiaoshu_part[cur]
        else:
            new_xiaoshu_part += "0"
        cur += 1
    a_li[-1] = a_li[-1].replace(xiaoshu_part, new_xiaoshu_part)
    return ".".join(a_li)

def post_process_answer(qustion, answer):
    ans = answer.replace("\"", "")
    ans = ans.replace("[", "").replace("]", "").replace("(","").replace(")","")
    q = qustion
    if "费率" in q and len(ans.split(",")) == 1 and '%' not in ans:
        ans += '%'
    if "百分比" in q and len(ans.split(",")) == 1 and '%' not in ans:
        ans += '%'
    if "涨跌幅" in q and '%' not in ans:
        m = re.search(p, ans)
        if m:
            ans = ans.replace(m.group(), m.group()+'%')
    if "取整" in q and ".0" in ans:
        ans = ans.replace('.0', "")
    q_li = re.split(bd, q)
    for sub_q in q_li:
        if "小数" in sub_q:
            if "不超过" in sub_q:
                continue
            else:
                for ji in weishu_dict:
                    if ji+"位" in sub_q:
                        ans_list = ans.split(",")
                        if len(ans_list) == 1:
                            ans_list[-1] = refine_xiaoshu(ans,ji)
                        else:
                            for i,a0 in enumerate(ans_list):
                                if a0.find(".") > 0:
                                    ans_list[i] = refine_xiaoshu(a0,ji)
                                    break
                        ans = ", ".join(ans_list)
                        break
    return ans

if __name__ == '__main__':
    input = 'groundtruth_1124.json'
    output = 'sql_answer_1124.json'
    with open(input) as f:
        data = json.loads(f.read())
        xiaoshu = 0
        for d in data:
            ans = d["a"]
            ans = ans.replace("\"", "")
            ans = ans.replace("[", "").replace("]", "").replace("(","").repalce(")","")
            q = d["q"]
            if "费率" in q and len(ans.split(",")) == 1 and '%' not in ans:
                ans += '%'
            if "百分比" in q and len(ans.split(",")) == 1 and '%' not in ans:
                ans += '%'
            if "涨跌幅" in q and '%' not in ans:
                m = re.search(p, ans)
                if m:
                    ans = ans.replace(m.group(), m.group()+'%')
            if "取整" in q and ".0" in ans:
                ans = ans.replace('.0', "")
            q_li = re.split(bd, q)
            for sub_q in q_li:
                if "小数" in sub_q:
                    if "不超过" in sub_q:
                        xiaoshu += 1
                        print(ans)
                        continue
                    else:
                        for ji in weishu_dict:
                            if ji+"位" in sub_q:
                                xiaoshu += 1
                                ans_list = ans.split(",")
                                if len(ans_list) == 1:
                                    ans_list[-1] = refine_xiaoshu(ans)
                                else:
                                    for i,a0 in enumerate(ans_list):
                                        if a0.find(".") > 0:
                                            ans_list[i] = refine_xiaoshu(a0)
                                            break
                                ans = ", ".join(ans_list)
                                break

            d["a"] = ans
        with open(output, "w") as f1:
            f1.write(json.dumps(data, ensure_ascii=False))

前端页面代码

使用Flask集成到我们改造之后的对话系统中,主要操纵原生DOM完成工作

代码

<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>基于金融知识</title>

    <!-- Bootstrap Stylesheet [ REQUIRED ] -->
    <link href="../static/css/bootstrap.min.css" rel="stylesheet">
    <link rel="shortcut icon" href="../static/images/山东大学校徽.jpg">
    <link href="../static/css/main.css" rel="stylesheet" type="text/css">
    <link href="../static/css/nifty.css" rel="stylesheet">
    <link href="../static/css/nifty-demo-icons.min.css" rel="stylesheet">
    <link href="../static/css/nifty-demo.min.css" rel="stylesheet">
    <link href="../static/css/font-awesome.min.css" rel="stylesheet">
    <link href="../static/css/wiki.css" rel="stylesheet">

    <!-- JAVASCRIPT -->
    <link href="../static/css/pace.min.css" rel="stylesheet">
    <script src="../static/js/pace.min.js"></script>
    <script src="../static/js/jquery-2.2.4.min.js"></script>
    <script src="../static/js/bootstrap.min.js"></script>
    <script src="../static/js/nifty.min.js"></script>
    <script src="../static/js/icons.js"></script>
    <script src="../static/js/echarts.min.js"></script>
    <script src="../static/js/nifty-demo.min.js"></script>

    <!-- Custom CSS -->
    <style>
        .centered-container {
            display: flex;
            justify-content: center;
            align-items: center;
            flex-direction: column;
        }

        .dialogue-box {
            width: 100%;
            max-width: 1200px;
            margin: 0 auto;
        }

        .col-l-t {
            height: 700px;
            overflow-y: auto;
            border: 1px solid #ddd;
            padding: 15px;
            border-radius: 5px;
            background-color: #f9f9f9;
        }

        .faq-input-content {
            display: flex;
            margin-top: 15px;
        }

        .faq-input-content textarea {
            flex-grow: 1;
            margin-right: 10px;
        }
    </style>
</head>

<body>
    <div id="container" class="effect aside-float aside-bright mainnav-lg">
        <!-- NAVBAR -->
        <header id="navbar">
            <div id="navbar-container" class="boxed">
                <div class="navbar-header">
                    <a href="{{ url_for('index') }}" class="navbar-brand">
                        <img src="../static/images/山东大学校徽.jpg" alt="Nifty Logo" class="brand-icon">
                        <div class="brand-title">
                            <span class="brand-text">llm研创</span>
                        </div>
                    </a>
                </div>
                <div class="navbar-content clearfix">
                    <ul class="nav navbar-top-links pull-left">
                        <li class="tgl-menu-btn">
                            <a class="mainnav-toggle" href="#">
                                <i class="demo-pli-view-list"></i>
                            </a>
                        </li>
                    </ul>
                    <ul class="nav navbar-top-links pull-left">
                        <h4>基于LLM的商业答疑系统</h4>
                    </ul>
                </div>
            </div>
        </header>
        <div class="boxed">
            <nav id="mainnav-container">
                <div id="mainnav">
                    <div id="mainnav-menu-wrap">
                        <div class="nano">
                            <div class="nano-content">
                                <ul id="mainnav-menu" class="list-group">
                                    <li class="list-divider"></li>
                                    <li class="list-header">问答系统</li>
                                    <li>
                                        <a href="{{ url_for('dialogue_page') }}">
                                            <i class="fa fa-comments" style="width:24px"></i>
                                            <span class="menu-title"><strong>多轮对话</strong></span>
                                        </a>
                                    </li>
                                </ul>
                            </div>
                        </div>
                    </div>
                </div>
            </nav>
            <div id="content-container" class="container-fluid">
                <div id="page-content" class="row centered-container">
                    <h4 class="text-main pad-btm bord-btm">对话系统</h4>
                    <div id="col-l" class="col-md-8 dialogue-box">
                        <div class="col-l-t">
                            <div class="content">
                                <div class="bubble" id="Chat">
                                    <!-- Initial Chat Content -->
                                </div>
                            </div>
                        </div>
                        <div class="col-l-b">
                            <div class="faq-input-content">
                                <label for="talkcontent"></label>
                                <textarea name="question" class="input form-control" id="talkcontent" placeholder="请输入你的问题(eg.湖南长远锂科股份有限公司变更设立时作为发起人的法人有哪些?)" autocomplete="off"></textarea>
                                <button id="sendbtn" class="btn btn-primary">发送</button>
                            </div>
                        </div>
                    </div>
                </div>
            </div>
        </div>
        <footer id="footer" class="text-center" style="margin-bottom: 0px">
           <p class="pad-lft">Copyright&#0169; 2024 llm小组  &nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;All Rights Reserved </p>
        </footer>
        <button class="scroll-top btn">
            <i class="pci-chevron chevron-up"></i>
        </button>
    </div>

    <!-- JavaScript for ChatSendClient() -->
    <script type="text/javascript">
    window.onload = function(){
        var input = document.getElementById('talkcontent');
        document.getElementById('sendbtn').onclick = function (){
            ChatSendClient();
        }
    }
    $(document).keypress(function (e) {
        if (e.which === 13) {
            ChatSendClient();
            event.returnValue = false;
        }
    });

    function ChatSendClient(){
        var text = $('#talkcontent').val(),
            $msgbox = $('#Chat'),
            sMesContent = '',
            aMesContent = '';

        if (text === ''){
            alert('请输入内容');
            return;
        } else {
            sMesContent = '<div class="msg fr"><span class="triangle right"></span><div class="article">' + text + '</div></div>';
            $msgbox.html($msgbox.html() + sMesContent);
            document.getElementById("talkcontent").value = "";
            $('.col-l-t').animate({ scrollTop: document.getElementById('Chat').scrollHeight + 'px' });

            // Add "Generating answer..." message
            var loadingMessage = '<div class="msg clearfix"><div class="user-assistant"></div><span class="triangle right"></span><div class="article">生成答案中...</div></div>';
            $msgbox.html($msgbox.html() + loadingMessage);
            $('.col-l-t').animate({ scrollTop: document.getElementById('Chat').scrollHeight + 'px' });

            $.getJSON('/dialogue_answer', { name: text }, function (result) {
                // Remove the loading message
                $('#Chat .msg').last().remove();
                // Add the answer
                aMesContent = '<div class="msg clearfix"><div class="user-assistant"></div><span class="triangle right"></span><div class="article">' + result.data.replace("\n","<br>") + '</div></div>';
                $msgbox.html($msgbox.html() + aMesContent);
                $('.col-l-t').animate({ scrollTop: document.getElementById('Chat').scrollHeight + 'px' });
            });
        }
    }
    </script>
</body>

</html>

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值