In AI applications, it is critical that generated text follows the correct format. LM Format Enforcer is a library that enforces the output format of a language model: it combines a character-level parser with a tokenizer prefix tree, allowing only those tokens that can still lead to a potentially valid sequence of format characters. This article digs into the core ideas behind LM Format Enforcer and walks through concrete code examples to help you understand and apply the technique.
Technical Background
It is hard to guarantee that the content generated by a large language model (LLM) strictly conforms to a specific format, which can cause problems in scenarios such as API calls and structured data storage. LM Format Enforcer addresses this by filtering tokens, combining a character-level parser with a tokenizer prefix tree.
How It Works
At each decoding step, LM Format Enforcer parses the characters generated so far and consults a prefix tree built from the tokenizer's vocabulary, keeping only those tokens whose characters can still extend a valid sequence for the target format. The generated text therefore follows the expected format exactly, reducing the risk of malformed output and messy data.
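To build intuition, here is a minimal, purely illustrative sketch (not the library's actual API): a character-level parser for a toy "digits only" format, combined with a brute-force scan over a hypothetical vocabulary. The real library organizes the vocabulary in a prefix tree so tokens sharing a prefix are validated together, but the filtering idea is the same.

# Hypothetical vocabulary of token strings -- purely illustrative
VOCAB = ["12", "1a", "345", "abc", "7"]

def parser_accepts(prefix: str) -> bool:
    # Character-level parser for a toy "digits only" format:
    # a prefix is valid iff every character so far is a digit.
    return prefix.isdigit() or prefix == ""

def allowed_tokens(generated_so_far: str) -> list[str]:
    # A token may be emitted next only if appending it character by
    # character keeps every intermediate prefix valid. The tokenizer
    # prefix tree makes this check cheap by sharing common prefixes.
    allowed = []
    for token in VOCAB:
        text = generated_so_far
        ok = True
        for ch in token:
            text += ch
            if not parser_accepts(text):
                ok = False
                break
        if ok:
            allowed.append(token)
    return allowed

print(allowed_tokens("42"))  # ['12', '345', '7'] -- non-digit tokens are filtered out

At generation time, this filter runs before each token is sampled, so the model can only ever choose tokens that keep the output on a valid path.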
Code Walkthrough
Environment Setup
%pip install --upgrade --quiet lm-format-enforcer langchain-huggingface > /dev/null
Setting Up the Model
The following code shows how to set up the Llama 2 model and define the output format we expect. Note that access to the Llama 2 model requires prior approval.
import logging
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from langchain_experimental.pydantic_v1 import BaseModel
logging.basicConfig(level=logging.ERROR)
class PlayerInformation(BaseModel):
    first_name: str
    last_name: str
    num_seasons_in_nba: int
    year_of_birth: int
model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cuda"
if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,  # 8-bit quantization to reduce GPU memory usage
        device_map="auto",
    )
else:
    raise Exception("GPU not available")

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    # a pad token is required for the batched generation used later on
    tokenizer.pad_token_id = tokenizer.eos_token_id
Baseline Model Output
First, let's look at the model's output without structured decoding.
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
prompt = """Please give me information about {player_name}. You must respond using JSON format, according to the following schema:
{arg_schema}
"""
def make_instruction_prompt(message):
    return f"[INST] <<SYS>>\n{DEFAULT_SYSTEM_PROMPT}\n<</SYS>> {message} [/INST]"

def get_prompt(player_name):
    return make_instruction_prompt(
        prompt.format(
            player_name=player_name, arg_schema=PlayerInformation.schema_json()
        )
    )

hf_model = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200
)
original_model = HuggingFacePipeline(pipeline=hf_model)
generated = original_model.predict(get_prompt("Michael Jordan"))
print(generated)
Enforcing the Output Format with LM Format Enforcer
Now we wrap the same pipeline in LMFormatEnforcer and pass it the JSON schema of the expected output:
from langchain_experimental.llms import LMFormatEnforcer
lm_format_enforcer = LMFormatEnforcer(
    json_schema=PlayerInformation.schema(), pipeline=hf_model
)
results = lm_format_enforcer.predict(get_prompt("Michael Jordan"))
print(results)
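Because the enforced output is guaranteed to match the schema, it can be parsed straight into the pydantic model. A minimal sketch, assuming `results` holds just the JSON object produced above:

# parse_raw validates the JSON against the schema's field types
player = PlayerInformation.parse_raw(results)
print(player.first_name, player.num_seasons_in_nba)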
Batch Processing
LM Format Enforcer also supports batch mode:
prompts = [
    get_prompt(name)
    for name in ["Michael Jordan", "Kareem Abdul Jabbar", "Tim Duncan"]
]
results = lm_format_enforcer.generate(prompts)
for generation in results.generations:
    print(generation[0].text)
Using Regular Expressions
LM Format Enforcer also supports filtering output with regular expressions:
question_prompt = "When was Michael Jordan Born? Please answer in mm/dd/yyyy format."
date_regex = r"(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}"
answer_regex = " In mm/dd/yyyy format, Michael Jordan was born in " + date_regex
lm_format_enforcer = LMFormatEnforcer(regex=answer_regex, pipeline=hf_model)
full_prompt = make_instruction_prompt(question_prompt)
print("Unenforced output:")
print(original_model.predict(full_prompt))
print("Enforced Output:")
print(lm_format_enforcer.predict(full_prompt))
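Since the enforced answer is guaranteed to match `answer_regex`, downstream code can pull the date out with an ordinary regex search. A short sketch reusing `date_regex` from above:

import re

enforced = lm_format_enforcer.predict(full_prompt)
match = re.search(date_regex, enforced)
if match:
    print("Extracted date:", match.group(0))  # e.g. 02/17/1963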
Use Cases
LM Format Enforcer is especially well suited to scenarios that require strictly formatted text, such as API return values and structured data storage. By combining character-level parsing with a tokenizer prefix tree, it substantially reduces output format errors and improves data quality.
Practical Tips
- In real applications, define the expected output format explicitly as a JSON Schema or a regular expression (see the hand-written schema sketch after this list).
- Use LM Format Enforcer's batch mode to improve generation throughput.
- Format enforcement guarantees structure, not factual accuracy; proofread generated content manually where correctness matters.
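As mentioned in the first tip, the schema does not have to come from a pydantic model: `PlayerInformation.schema()` is just a dict, so an equivalent hand-written JSON Schema works too. A sketch:

# hand-written JSON Schema equivalent to PlayerInformation.schema()
player_schema = {
    "type": "object",
    "properties": {
        "first_name": {"type": "string"},
        "last_name": {"type": "string"},
        "num_seasons_in_nba": {"type": "integer"},
        "year_of_birth": {"type": "integer"},
    },
    "required": ["first_name", "last_name", "num_seasons_in_nba", "year_of_birth"],
}
lm_format_enforcer = LMFormatEnforcer(json_schema=player_schema, pipeline=hf_model)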
If you run into any problems, feel free to discuss them in the comments.