使用LM Format Enforcer强制语言模型输出格式

在AI应用中,确保生成文本的格式正确是非常重要的。LM Format Enforcer是一个可以强制语言模型输出格式的库,采用字符级解析器与标记器前缀树相结合的方法,允许只包含潜在有效格式字符序列的标记。本文将深入探讨LM Format Enforcer的核心原理,并结合具体代码示例,帮助大家更好地理解和应用这项技术。

技术背景介绍

语言模型(LLM)生成的内容往往难以确保其严格符合特定的格式要求,这在API调用、数据存储等场景中可能会带来问题。LM Format Enforcer通过过滤标记、结合字符级解析器与标记器前缀树,有效解决了这一问题。

核心原理解析

LM Format Enforcer通过解析字符序列并构建标记器前缀树,只保留符合特定格式的标记。这样,当生成文本时,输出的内容可以严格按照预期的格式呈现,从而减少格式错误和数据混乱的风险。

代码实现演示

环境配置

%pip install --upgrade --quiet lm-format-enforcer langchain-huggingface > /dev/null

设置模型

以下代码展示如何设置LLama2模型,初始化我们期望的输出格式。注意,LLama2模型的访问需要事先获得批准。

import logging
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from langchain_experimental.pydantic_v1 import BaseModel

logging.basicConfig(level=logging.ERROR)

class PlayerInformation(BaseModel):
    first_name: str
    last_name: str
    num_seasons_in_nba: int
    year_of_birth: int

model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
    )
else:
    raise Exception("GPU not available")
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

基本模型输出

首先,我们不使用结构化解码方式来展示模型输出。

from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

prompt = """Please give me information about {player_name}. You must respond using JSON format, according to the following schema:

{arg_schema}

"""

def make_instruction_prompt(message):
    return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> {message} [/INST]"

def get_prompt(player_name):
    return make_instruction_prompt(
        prompt.format(
            player_name=player_name, arg_schema=PlayerInformation.schema_json()
        )
    )

hf_model = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200
)

original_model = HuggingFacePipeline(pipeline=hf_model)

generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)

使用LM Format Enforcer强制输出格式

from langchain_experimental.llms import LMFormatEnforcer

lm_format_enforcer = LMFormatEnforcer(
    json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)

批量处理

LM Format Enforcer还支持批量模式:

prompts = [
    get_prompt(name) for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]
]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
    print(generation[0].text)

使用正则表达式

LM Format Enforcer还支持使用正则表达式来过滤输出:

question_prompt = "When was Michael Jordan Born? Please answer in mm/dd/yyyy format."
date_regex = r"(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}"
answer_regex = " In mm/dd/yyyy format, Michael Jordan was born in " + date_regex

lm_format_enforcer = LMFormatEnforcer(regex=answer_regex, pipeline=hf_model)

full_prompt = make_instruction_prompt(question_prompt)
print("Unenforced output:")
print(original_model.predict(full_prompt))
print("Enforced Output:")
print(lm_format_enforcer.predict(full_prompt))

应用场景分析

LM Format Enforcer特别适用于需要输出严格格式化文本的场景,如API调用返回值、结构化数据存储等。通过结合字符解析和标记器前缀树,能够有效减少输出格式错误,提高数据质量。

实践建议

  1. 在实际应用中,确保明确定义输出格式的JSON Schema或正则表达式。
  2. 利用LM Format Enforcer的批量处理能力,提升生成效率。
  3. 为确保生成内容的准确性,必要时进行人工校对。

如果遇到问题欢迎在评论区交流。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值