如何使用大模型高效生产数据[含完整代码]

大模型出现之前我们的训练数据大都依赖人工标注、开源数据以及从线上数据中构造合适的监督数据,如果开源数据不太符合我们的业务需求(大部分情况下无法直接满足要求),且已有的线上数据也没办法抽取出符合要求的监督数据,这个时候恐怕只能依赖于人工标注了,但是人工标注又非常的耗费人力和时间。大模型出现后给我们提供了新的选择,我们可以通过构造高质量的prompt使用大模型给我们生产数据。原理其实很简单,所以本次分享的重点其实不在于原理,主要是想将本人工作中经常使用的一套代码分享出来,供大家直接使用

完整代码见:Data Generate Template | LlamaFactory

大致流程

当接到一个业务需求时跟产品对齐细节后就可以开始写prompt了(这里假设硬件资源支撑不了满足效果的例如72b模型,1.5b以及7b等模型直接用效果又无法达标)。我会先用vllm将效果好的大模型部署起来方便使用openai的sdk调用,反复调试迭代prompt差不多达到要求后我们就可以开始生产用于训练小模型的数据了。一般使用多进程加速数据生产。下面从代码层面讲讲具体的细节。

vllm部署大模型

vllm提供了非常方便的命令行部署命令:

CUDA_VISIBLE_DEVICES="0,1" python -m vllm.entrypoints.openai.api_server --served-model-name model_name  --model model_path --tensor-parallel-size 2 --port 8002

假设你的启动过程十分顺利,这时候你在终端就能看见打印出来的访问地址,一般是http://0.0.0.0:8002,这个时候你在浏览器中输入http://0.0.0.0:8002/docs就可以访问到一个可交互的文档界面,可以在这里尝试访问服务,看看是否可以正常调用。

调试prompt

vllm非常贴心的提供了一个基于gradio的示例代码供大家使用,调试prompt则会更加方便,代码在这里大家可以自取。但是我自己做了一点修改,可以在界面上直接修改prompt,这样就不用每次修改prompt后重启服务了,相对来说方便一点。修改后的版本如下:

import argparse
from collections.abc import Generator

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description="Chatbot Interface with Customizable Parameters"
)
parser.add_argument(
    "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
)
parser.add_argument(
    "-m", "--model", type=str, default="gpt-3.5-turbo", help="Model name for the chatbot"
)
parser.add_argument(
    "--temp", type=float, default=0.8, help="Temperature for text generation"
)
parser.add_argument(
    "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
)
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def predict(
    message: str, history: list[tuple[str, str]], system_message: str
) -> Generator[str, None, None]:
    # Convert chat history to OpenAI format
    history_openai_format = [{"role": "system", "content": system_message}]

    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # type: ignore  # Chat history
        # temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            "repetition_penalty": 1,
            "stop_token_ids": (
                [int(id.strip()) for id in args.stop_token_ids.split(",") if id.strip()]
                if args.stop_token_ids
                else []
            ),
        },
        max_tokens=2048,
    )

    # Read and return generated text from response stream
    partial_message = ""
    for chunk in stream:
        partial_message += chunk.choices[0].delta.content or ""  # type: ignore
        yield partial_message


# Create and launch a chat interface with Gradio
gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Textbox("you are a helpful assistant", label="System Prompt"),
    ],
    additional_inputs_accordion=gr.Accordion(open=True),
).queue().launch(server_name=args.host, server_port=args.port, share=True)

服务启动过程中可能出现下面的信息,问题其实不解决也行,只是不能够分享给他人使用了,在自己本地上访问是没问题的。但由于我可能需要给到产品去体验效果,所以把这个问题修复了下。

Please check your internet connection. This can happen if your antivirus software blocks the download of this file. You can install manually by following these steps: 

1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_darwin_arm64
2. Rename the downloaded file to: frpc_darwin_arm64_v0.2
3. Move the file to this location: /xxx/.venv/lib/python3.12/site-packages/gradio

修复上面问题的具体步骤如下:

wget https://cdn-media.huggingface.co/frpc-gradio-0.2/frpc_darwin_arm64
mv frpc_darwin_arm64 frpc_darwin_arm64_v0.2
chmod +x frpc_darwin_arm64_v0.2
mv frpc_darwin_arm64_v0.2 you_gradio_path_in_env

再次启动时,就会打印出两个地址,如下。第二个地址可以分享给其他人访问

Running on local URL:  http://127.0.0.1:8001
Running on public URL: https://24e925b09b9a9c337d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)

我这边使用deepseek作为示例,打开地址后就可以看到如下对话界面。此时你可以快速地在下方的System Prom中迭代,在submit输入框中输入业务数据获得模型输出,查看输出是否符合要求。

这一步完成后我们应该在超大模型的基础上获得了一个不错的效果,能够满足业务要求。

大规模蒸馏数据

这一步其实主要是将前面启对话界面的代码改写成读取待标注的数据,并使用多进程调用vllm启动的服务,主要代码如下:

# 读取输入数据
df = pd.read_json(CONFIG["INPUT_FILE"], lines=True)
# 并行处理数据
with ProcessPoolExecutor(max_workers=CONFIG["MAX_WORKERS"]) as executor:
    list(tqdm(executor.map(process_row, df.to_dict(orient="records")), total=len(df)))

process_row用于单次调用处理一条数据,具体实现可参考如下代码。主要逻辑是请求大模型获取标注结果,并保存每一条结果(当数据量较大时,这么做容错性比较高,不至于程序出错就会全部都需要重新标注),同时可以在post_process实现一定的后处理。

def process_row(row):
    try:
        user_input = USER_INPUT_TEMPLATE.format(**row)
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": user_input},
        ]
        response = (
            client.chat.completions.create(
                model=CONFIG["MODEL_NAME"],
                messages=messages,
            )
            .choices[0]
            .message.content
        )

        post_process(row, response)
    except Exception as e:
        print(f"处理数据时出错: {e}")
        print(f"跳过数据: {row.get('id', 'unknown')}")


def post_process(row, response):
    # 在此处理模型的响应,例如输出是json,可使用json.loads(response)
    # 示例:将响应直接添加到row中
    row["model_response"] = response

    # 生成唯一ID并保存处理后的数据
    unique_id = str(uuid.uuid4())
    filename = f"{CONFIG['PROCESSED_DIR']}/{unique_id}.json"
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(row, f, ensure_ascii=False, indent=4)

整体的生成时间可以通过tqdm较好的把控观察,一般来说几十万条数据两三天就够了,不过具体还是要看你的任务数据长度。

小结

完整的代码可以在这里这里找到。以上就是本人在工作中最常用的蒸馏数据的代码,不一定是最佳实践,但是目前对于我来说够用了。本文只着眼于如何高效的产出数据,略去了具体的一些细节;例如prompt如何迭代优化、不同的场景如何生产出优质的数据,后面可以单开一篇举例聊聊。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值