【电商】基于LangChain框架将多模态大模型连接数据库实现精准识别

1. LangChain框架

LangChain是一个用于构建基于大语言模型的应用框架,通过模块化设计简化了LLM与外部工具,数据源和复杂逻辑的集成。

连接能力

将多个LLM调用,工具调用或者数据处理步骤串联成工作流

数据感知

外部数据集成

支持连接数据库,API,解决LLM的知识截止问题

from langchain_community.document_loaders import CSVLoader
loader = CSVLoader(你的文件路径)

记忆管理

自动跟踪多轮对话历史,支持短期(内存,调包)或者长期(数据库)存储

from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory()

langchain支持本地模型

2. 多模态大模型连接数据库初始化设置

# 插入数据
def insert_data(connection, dataset):
    connection.execute(users.insert(), dataset)

def select_data(connection):
    result = connection.execute(users.select())
    for row in result:
        print(row)
def get_table_schema(engine):
    inspector = reflection.Inspector.from_engine(engine)
    table_names = inspector.get_table_names()
    schema = {}
    for table_name in table_names:
        columns = inspector.get_columns(table_name)
        schema[table_name] = [column['name'] for column in columns]
    return schema

def execute_query(query):
    with engine.connect() as conn:
        result = conn.execute(text(query))
        return result.fetchall()
    
def query_database(prompt, schema):
    # 将表结构信息包含在提示中
    schema_info = "\n".join([f"Table {table}: {', '.join(columns)}" for table, columns in schema.items()])
    full_prompt = f"""
    以下是数据库的表结构信息:
    {schema_info}
    请根据图片信息生成一个SQL查询
    请严格按照表结构生成SQL查询在</answer>里面显示
    """
    prompt_final = PromptTemplate(
        input_variables=[schema_info],  
        template=full_prompt
    )
    print("***************************************************************************")
    print("full_prompt:", full_prompt.replace('\n', ''))
    print("***************************************************************************")
    return full_prompt.replace('\n', '')

3. 输出结果

在这里插入图片描述

多模态推理过程

message_search = [
            # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {
            "role": "user",
            "content": [
                {
                    "type": "image", 
                    "image": f"file://{image}"
                },
                {
                    "type": "text",
                    "text": query_database()
                }
            ]
        }]
    messages_prompt.append(message_search)
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages_prompt]
    print("*************************************")
    print("text:", text)
    print("*************************************")
    image_inputs, video_inputs = process_vision_info(messages_prompt)
    print("*************************************")
    print("image_inputs:", image_inputs)
    print("*************************************")
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    print("*************************************")
    print("inputs:", inputs)
    print("*************************************")
    inputs = inputs.to("cuda:0")
    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=300, do_sample=False)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    all_outputs.extend(batch_output_text)
    print("==========查询结果")
    print(all_outputs[0])

4. 编写获取sql语句函数

def extract_sql_answer(content):
    answer_tag_pattern = r'<Answer>(.*?)</Answer>'
    sql_pattern = r'```sql(.*?)```'
    content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
    if content_answer_match:
        content_answer = content_answer_match.group(1).strip()
        sql_match = re.search(sql_pattern, content_answer, re.DOTALL)
        if sql_match:
            sql_content = sql_match.group(1).strip()
            posibble_things = execute_query(sql_content)
            return posibble_things
    elif "sql" in content:
        sql_match = re.search(sql_pattern, content, re.DOTALL)
        if sql_match:
            sql_content = sql_match.group(1).strip()
            posibble_things = execute_query(sql_content)
            return posibble_things
    return ""

5. 将query查询语句与数据库连接并返回查询结果

def execute_query(query):
    query = query.replace("商品表", "users")
    with engine.connect() as conn:
        result = conn.execute(text(query))
        return result.fetchall()

6. 结果展示

在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值