Lawyer LLaMA(中文法律大模型本地部署)

Lawyer LLaMA(中文法律大模型本地部署)

1. 模型选择(lawyer-llama-13b-v2)
2.运行环境

​ 1.建议使用Python 3.8及以上版本。

​ 2.主要依赖库如下:

  • transformers >= 4.28.0 注意:检索模块需要使用transformers <= 4.30
  • sentencepiece >= 0.1.97
  • gradio
3.使用步骤

1. 从 HuggingFace 下载 **Lawyer LLaMA 2 (lawyer-llama-13b-v2)** 模型参数(需要预先安装 PyTorch)。

# Use a pipeline as a high-level helper
from transformers import pipeline

# Load lawyer-llama-13b-v2 from the Hugging Face Hub as a text-generation pipeline.
pipe = pipeline(task="text-generation", model="pkupie/lawyer-llama-13b-v2")

2. 从 HuggingFace 下载法条检索模块,并运行其中的 python server.py 启动法条检索服务,默认监听 9098 端口。(注意:拉取的代码中可能缺少 labels2id.pkl、pytorch_model.bin 等文件。)

​ 1.git lfs install

​ 2.git clone https://huggingface.co/pkupie/marriage_law_retrieval

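3. (可选,替代第 2 步)如果只想先拉取代码、暂不下载 LFS 大文件,可使用:GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/pkupie/marriage_law_retrieval 。这种方式下 pytorch_model.bin 等大文件只是占位指针文件,之后需在仓库目录中执行 git lfs pull 才能下载真正的文件。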

4. server.py 的代码如下,其中的模型路径需根据本地环境手动修改:

import json
import subprocess
import os
import codecs
import logging
import os
import math

import json
import random
from tqdm import tqdm
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

from flask import Flask, request, jsonify
import json
import random
from tqdm import tqdm
import os
import pickle as pkl
from argparse import Namespace

from models import Elect

import torch
from transformers import AutoModel, AutoTokenizer

from sklearn.preprocessing import MultiLabelBinarizer

logger = logging.getLogger(__name__)

app = Flask(__name__)

# Globals filled in by the __main__ block below and read by check_hunyin().

# Text-classification pipeline that flags marriage-related questions.
hunyin_classifier = None

# State for the law-article ("fatiao") retrieval model.
fatiao_args = Namespace()
fatiao_model = None
fatiao_tokenizer = None


def _is_marriage_label(label):
    """Return True when a classifier label means "marriage-related".

    A text-classification pipeline reports ``config.id2label[i]``. That value
    is a bool only if the checkpoint config stores bools; with the
    transformers default it is a string such as ``"LABEL_0"``, which never
    compares equal to ``False``. Both forms are accepted here.

    NOTE(review): assumes class index 0 means "not marriage-related", which
    matches the ``== False`` comparison this replaces. Confirm against the
    checkpoint's id2label.
    """
    if isinstance(label, bool):
        return label
    if isinstance(label, int):
        return label != 0
    return str(label).strip().lower() not in {"false", "0", "label_0"}


def _rank_laws(scores, classes, top_k=3):
    """Return the ``top_k`` highest-scoring laws as JSON-ready dicts.

    ``scores[i]`` is the score of law ``classes[i]``. Ties keep index order,
    because ``sorted`` is stable.
    """
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return [
        {'id': law_id, 'score': float(score), 'text': classes[law_id]}
        for law_id, score in ranked[:top_k]
    ]


@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
    """Return up to three marriage-law articles relevant to the request text.

    Request JSON: ``{"input": str, "force_return": bool (optional)}``.

    Response JSON: ``{"output": [{"id": int, "score": float, "text": str}, ...]}``.
    ``output`` is empty when the input is empty, or when the input is judged
    not marriage-related. That check is skipped if ``force_return`` is truthy.
    """
    # get_json(silent=True) returns None instead of raising on a missing or
    # non-JSON body (e.g. a bare GET); such requests get an empty result.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    input_text = str(payload.get('input') or '').strip()
    force_return = payload.get('force_return', False)

    print("input_text:", input_text)

    if not input_text:
        return jsonify({"output": []})

    if not force_return:
        # Rule: any text containing "婚" (marriage) counts as marriage-related,
        # so the classifier only runs when that rule does not already decide.
        if '婚' not in input_text:
            classifier_result = hunyin_classifier(input_text[:500])
            print(classifier_result)
            if not _is_marriage_label(classifier_result[0]['label']):
                # Not marriage-related: return no articles.
                return jsonify({"output": []})

    inputs = fatiao_tokenizer(input_text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
    batch = {
        'ids': inputs['input_ids'],
        'mask': inputs['attention_mask'],
        'token_type_ids': inputs["token_type_ids"],
    }
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        model_output = fatiao_model(batch)
    pred = torch.sigmoid(model_output).cpu().numpy()[0]

    json_result = {
        "output": _rank_laws(pred, fatiao_args.mlb.classes_, top_k=3)
    }

    print("json_result:", json_result)
    return jsonify(json_result)


if __name__ == "__main__":
    # Load both models into the module-level globals read by check_hunyin(),
    # then serve the retrieval API on port 9098.
    #
    # The default paths are the author's local paths. Each one can be
    # overridden with an environment variable instead of editing this file:
    # HUNYIN_CLASSIFIER_PATH, FATIAO_CKPT_DIR, LABELS2ID_PATH, ELECT_CKPT_PATH.

    # --- Classifier that decides whether a question is marriage-related ---
    hunyin_classifier_path = os.environ.get(
        "HUNYIN_CLASSIFIER_PATH",
        "C:/Users/win10/PycharmProjects/lawyer-llama_/marriage_law_retrieval/pretrained_models/roberta_wwm_ext_hunyin_2epoch/",
    )

    # Diagnostics for the "files missing after git clone" problem. This is
    # only a warning: a checkpoint may store its weights under another file
    # name (e.g. model.safetensors), in which case loading can still succeed.
    model_file = os.path.join(hunyin_classifier_path, "pytorch_model.bin")

    print("Files in directory:")
    for filename in os.listdir(hunyin_classifier_path):
        print(filename)

    if os.path.exists(model_file):
        print(f"Model file found at {model_file}")
    else:
        print(f"Model file not found at {model_file}")

    hunyin_config = AutoConfig.from_pretrained(hunyin_classifier_path, num_labels=2)
    hunyin_tokenizer = AutoTokenizer.from_pretrained(hunyin_classifier_path)
    hunyin_model = AutoModelForSequenceClassification.from_pretrained(
        hunyin_classifier_path,
        config=hunyin_config,
    )
    # device=0 is the first CUDA GPU, the same device as fatiao_args.device below.
    hunyin_classifier = pipeline(model=hunyin_model, tokenizer=hunyin_tokenizer, task="text-classification", device=0)

    print("Model loaded successfully")

    # --- Law-article retrieval model ---
    fatiao_args.ckpt_dir = os.environ.get(
        "FATIAO_CKPT_DIR",
        r"C:\Users\win10\PycharmProjects\lawyer-llama_\marriage_law_retrieval\pretrained_models\chinese-roberta-wwm-ext",
    )
    fatiao_args.device = "cuda:0"

    # Relative paths resolve against the current working directory, so run
    # this script from the marriage_law_retrieval checkout or set the env vars.
    labels2id_path = os.environ.get("LABELS2ID_PATH", os.path.join("data", "labels2id.pkl"))
    if os.path.exists(labels2id_path):
        print(f"Labels2id file found at {labels2id_path}")
    else:
        print(f"Labels2id file not found at {labels2id_path}")

    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # labels2id.pkl from a source you trust.
    with open(labels2id_path, "rb") as f:
        laws2id = pkl.load(f)
        fatiao_args.labels = list(laws2id.keys())

    id2laws = {v: k for k, v in laws2id.items()}
    print("法条个数:", len(id2laws))

    fatiao_tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)

    fatiao_args.tokenizer = fatiao_tokenizer
    # Use the single device setting rather than repeating the literal.
    fatiao_model = Elect(fatiao_args, fatiao_args.device).to(fatiao_args.device)
    fatiao_model.eval()

    # NOTE(review): MultiLabelBinarizer.classes_ is sorted, so its order can
    # differ from fatiao_args.labels (used for the `la` init below).
    # check_hunyin() maps output indices through classes_; confirm this
    # matches the ordering the ELECT checkpoint was trained with.
    mlb = MultiLabelBinarizer()
    mlb.fit([fatiao_args.labels])
    fatiao_args.mlb = mlb

    # Add the encoder's first-token hidden state for each law's text (the
    # part after the first ':') onto its label embedding la[idx].
    # NOTE(review): if `la` is part of the state dict loaded just below, this
    # initialisation is overwritten; confirm whether it is still needed.
    with torch.no_grad():
        for idx, label in enumerate(fatiao_args.labels):
            text = label.partition(':')[2].lower()
            la_in = fatiao_tokenizer(text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
            ids = la_in['input_ids'].to(fatiao_args.device)
            mask = la_in['attention_mask'].to(fatiao_args.device)
            fatiao_model.la[idx] += (fatiao_model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]).squeeze(0)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    elect_ckpt_path = os.environ.get("ELECT_CKPT_PATH", './pretrained_models/ELECT')
    fatiao_model.load_state_dict(torch.load(elect_ckpt_path, map_location=torch.device(fatiao_args.device)))
    fatiao_model.to(fatiao_args.device)

    logger.info("model loaded")
    app.run(host="0.0.0.0", port=9098, debug=False)

5. 如需使用 nginx 反向代理访问此服务,可参考 https://github.com/LeetJoe/lawyer-llama/blob/main/demo/nginx_proxy.md (Credit to @LeetJoe)

1. 启动命令(内存不足时使用):python demo_web.py --port 7863 --checkpoint "C:/Users/win10/.cache/huggingface/hub/models--pkupie--lawyer-llama-13b-v2/snapshots/f61a4a16c97b6bd546790d88eaec7bc7fcd7344b" --classifier_url "http://127.0.0.1:9098/check_hunyin" --offload_folder "C:/path/to/offload/folder"(其中 --offload_folder "C:/path/to/offload/folder" 用于指定一个目录,用来存放模型的部分数据,从而减轻内存负担。这是处理大模型时的常用策略:把部分模型权重卸载到磁盘上,以节省系统内存(RAM)的使用。注意命令中的引号须为英文半角引号,路径中的 models--pkupie--... 为两个连字符。)

2. 启动命令(内存充足时使用):python demo_web.py --port 7863 --checkpoint "C:/Users/win10/.cache/huggingface/hub/models--pkupie--lawyer-llama-13b-v2/snapshots/f61a4a16c97b6bd546790d88eaec7bc7fcd7344b" --classifier_url "http://127.0.0.1:9098/check_hunyin"

demo_web.py代码

import gradio as gr
import requests
import json
from transformers import LlamaForCausalLM, LlamaTokenizer, TextIteratorStreamer
import torch
import threading
import argparse

class StoppableThread(threading.Thread):
    """A ``threading.Thread`` that carries a cooperative stop flag.

    ``stop()`` only sets the flag; it does not interrupt the running target.
    Code running in the thread must poll ``stopped()`` and return on its own.
    """

    def __init__(self, *args, **kwargs):
        # Forward every Thread argument unchanged (target, kwargs, name, ...).
        super().__init__(*args, **kwargs)
        self._stop_event = threading.Event()

    def stop(self):
        """Request that the thread stop (sets the flag)."""
        self._stop_event.set()

    def stopped(self):
        """Return True once stop() has been called."""
        return self._stop_event.is_set()

def json_send(url, data=None, method="POST", timeout=60):
    """Send a JSON request to ``url`` and return the decoded JSON response.

    Args:
        url: Target endpoint.
        data: Optional JSON-serialisable payload. It is only sent with POST.
        method: "POST" or "GET" (case-insensitive).
        timeout: Seconds to wait for the server, passed to ``requests``.
            Without a timeout, a stalled server would block the caller forever.

    Returns:
        The decoded JSON body. Returns ``{}`` if the request fails, the server
        answers with an HTTP error status, or the body is not valid JSON.

    Raises:
        ValueError: if ``method`` is neither GET nor POST. Previously this case
            crashed with an UnboundLocalError.
    """
    method = method.upper()
    if method not in ("GET", "POST"):
        raise ValueError(f"Unsupported HTTP method: {method!r}")

    # The charset belongs inside Content-Type (a standalone "charset" header
    # has no meaning), and the reply is parsed as JSON, so ask for JSON.
    headers = {"Content-Type": "application/json; charset=UTF-8", "Accept": "application/json"}
    body = json.dumps(data) if method == "POST" and data is not None else None
    try:
        response = requests.request(method, url, headers=headers, data=body, timeout=timeout)
        response.raise_for_status()  # Turn 4xx/5xx replies into exceptions.
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers invalid JSON bodies on older requests versions.
        print(f"HTTP Request failed: {e}")
        return {}

if __name__ == "__main__":
    # Parse CLI options, load Lawyer LLaMA, and serve a Gradio chat UI. When
    # the retrieval service returns law articles, they are added to the prompt.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--checkpoint", type=str, default="")
    parser.add_argument("--classifier_url", type=str, default="")
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--offload_folder", type=str, default="./offload")
    args = parser.parse_args()
    checkpoint = args.checkpoint
    classifier_url = args.classifier_url

    print("Loading model...")
    tokenizer = LlamaTokenizer.from_pretrained(checkpoint)
    if args.load_in_8bit:
        model = LlamaForCausalLM.from_pretrained(checkpoint, device_map="auto", load_in_8bit=True, offload_folder=args.offload_folder)
    else:
        model = LlamaForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16, offload_folder=args.offload_folder)
    print("Model loaded.")

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        input_msg = gr.Textbox(label="Input")
        with gr.Row():
            generate_button = gr.Button('Generate', elem_id='generate', variant='primary')
            clear_button = gr.Button('Clear', elem_id='clear', variant='secondary')

        def user(user_message, chat_history):
            """Clear the textbox and append the new user turn (reply pending)."""
            user_message = user_message.strip()
            return "", chat_history + [[user_message, None]]

        def build_prompt(chat_history, retrieve_output):
            """Build the model prompt from the chat history and retrieved laws.

            The last history entry is the current question. Up to three
            retrieved articles are listed under "### 参考法条". Any number from
            0 to 3 is handled; the old code required exactly 3 or more.
            """
            if retrieve_output:
                parts = ["你是人工智能法律助手“Lawyer LLaMA”,能够回答与中国法律相关的问题。请参考给出的\"参考法条\",回复用户的咨询问题。\"参考法条\"中可能存在与咨询无关的法条,请回复时不要引用这些无关的法条。\n"]
            else:
                parts = ["你是人工智能法律助手“Lawyer LLaMA”,能够回答与中国法律相关的问题。\n"]
            for history_pair in chat_history[:-1]:
                parts.append(f"### Human: {history_pair[0]}\n### Assistant: {history_pair[1]}\n")
            parts.append(f"### Human: {chat_history[-1][0]}\n")
            if retrieve_output:
                law_texts = "\n".join(law['text'] for law in retrieve_output[:3])
                parts.append(f"### 参考法条: {law_texts}\n")
            parts.append("### Assistant: ")
            return "".join(parts)

        def bot(chat_history):
            """Retrieve laws, then stream the model's reply into the last turn."""
            current_user_input = chat_history[-1][0]

            if len(current_user_input) == 0:
                # Drop the empty turn that user() appended.
                yield chat_history[:-1]
                return

            # Retrieve law articles using every user turn so far.
            history_user_input = [x[0] for x in chat_history]
            input_to_classifier = " ".join(history_user_input)
            data = {"input": input_to_classifier}
            result = json_send(classifier_url, data, method="POST")
            retrieve_output = result.get('output', []) if isinstance(result, dict) else []

            input_text = build_prompt(chat_history, retrieve_output)

            print("=== Input ===")
            print("input_text: ", input_text)

            inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
            generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=400, do_sample=False, repetition_penalty=1.1)

            def generate():
                # If generate() raises (e.g. CUDA OOM), the streamer never gets
                # its end signal, and the loop below would block forever.
                try:
                    model.generate(**generation_kwargs)
                except Exception:
                    streamer.end()
                    raise

            thread = StoppableThread(target=generate)
            thread.start()

            # Stream the generated text into the last chat turn.
            chat_history[-1][1] = ""
            for new_text in streamer:
                chat_history[-1][1] += new_text
                yield chat_history

            # The streamer is exhausted once generation has finished. Join the
            # worker; stop() would be a no-op because generate() never polls it.
            thread.join()
            print("Output: ", chat_history[-1][1])

        input_msg.submit(user, [input_msg, chatbot], [input_msg, chatbot], queue=False).then(
            bot, [chatbot], chatbot
        )
        generate_button.click(user, [input_msg, chatbot], [input_msg, chatbot], queue=False).then(
            bot, [chatbot], chatbot
        )

    demo.queue()
    demo.launch(share=False, server_port=args.port, server_name='0.0.0.0')
  • 17
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值