【Hugging Face】transformers 库中 model.generate() 方法:自回归模型的文本生成方法

Hugging Face model.generate 方法

model.generatetransformers 库中的 文本生成(Text Generation) 方法,适用于 自回归模型(如 GPT-2、T5、BART、LLAMA),用于 生成文本、摘要、翻译、问答等


1. 适用于哪些模型?

generate 适用于 基于 Transformer 生成文本的模型,例如:

  • GPT-2 (AutoModelForCausalLM) → 文本生成
  • T5 / BART (AutoModelForSeq2SeqLM) → 翻译、摘要、填空
  • LLaMA / Mistral / Falcon (AutoModelForCausalLM) → 长文本生成
  • ChatGLM / Bloom / GPT-J (AutoModelForCausalLM) → 对话任务

不适用于:

  • BERT (AutoModelForMaskedLM)
  • RoBERTa (AutoModelForMaskedLM)

2. generate 的基本用法

2.1. 加载预训练模型

from transformers import AutoModelForCausalLM, AutoTokenizer

# 选择 GPT-2 模型
model_name = "gpt2"

# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

2.2. 生成文本

# 输入文本
input_text = "Hugging Face is"

# 编码输入
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# 生成文本
output_ids = model.generate(input_ids, max_length=50)

# 解码文本
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(output_text)

示例输出

Hugging Face is a company based in New York that specializes in AI and NLP...

3. generate 方法的关键参数

参数作用默认值
max_length生成文本的最大长度20
min_length生成文本的最小长度10
num_return_sequences生成多个结果1
temperature控制生成文本的随机性(越高越随机)1.0
top_k选择概率最高的 k 个词50
top_p只选择累计概率超过 p 的词(Nucleus Sampling)0.95
do_sample是否启用随机采样False
repetition_penalty惩罚重复生成的文本1.0
early_stopping是否提前停止False

4. generate 方法的高级用法

4.1. 生成多个文本

output_ids = model.generate(input_ids, max_length=50, num_return_sequences=3)
output_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]

for i, text in enumerate(output_texts):
    print(f"Generated {i+1}: {text}")

示例输出

Generated 1: Hugging Face is a leading AI company...
Generated 2: Hugging Face is developing new NLP tools...
Generated 3: Hugging Face is revolutionizing deep learning...

4.2. 控制文本的创造性

  • 提高 temperature(0.7 - 1.5)让文本更随机
model.generate(input_ids, max_length=50, temperature=1.2)
  • 设置 top_k 让文本更保守
model.generate(input_ids, max_length=50, top_k=10)
  • 设置 top_p(Nucleus Sampling)让文本更加自然
model.generate(input_ids, max_length=50, top_p=0.9)
  • 启用 do_sample=True 进行随机采样
model.generate(input_ids, max_length=50, do_sample=True, temperature=0.9)

4.3. 避免重复生成

如果模型重复生成相同的短语,可以使用 repetition_penalty(1.2-2.0)

model.generate(input_ids, max_length=50, repetition_penalty=1.5)

4.4. 使用 bad_words_ids 避免生成特定单词

bad_words = ["stupid", "bad"]
bad_word_ids = tokenizer(bad_words, add_special_tokens=False).input_ids

model.generate(input_ids, max_length=50, bad_words_ids=bad_word_ids)

4.5. 控制文本的长度

  • 限制最小/最大长度
model.generate(input_ids, min_length=30, max_length=100)
  • 使用 early_stopping 提前终止
model.generate(input_ids, max_length=50, early_stopping=True)

5. generate 在不同任务中的应用

5.1. 文本生成(GPT-2, LLaMA, Bloom)

model = AutoModelForCausalLM.from_pretrained("gpt2")
input_text = "Artificial intelligence is"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output_ids = model.generate(input_ids, max_length=50, do_sample=True, temperature=0.9)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)

5.2. 机器翻译(T5, MarianMT)

model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
input_text = "Hugging Face is a great company"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output_ids = model.generate(input_ids, max_length=50)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)  # "Hugging Face est une grande entreprise"

5.3. 摘要生成(BART, T5)

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
input_text = "Hugging Face is a company based in New York that specializes in NLP. It provides tools and libraries for building AI applications."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output_ids = model.generate(input_ids, max_length=30)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)  # "Hugging Face is an NLP company providing AI tools."

5.4. 问答任务(T5, BART)

model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-small-qa")
input_text = "question: What is Hugging Face? context: Hugging Face is an AI company specializing in NLP."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output_ids = model.generate(input_ids, max_length=20)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)  # "An AI company specializing in NLP"

6. 总结

  1. generate 是 Hugging Face 提供的文本生成方法,适用于 GPT、T5、BART 等模型
  2. 支持多种生成策略
    • temperature 控制创造性
    • top_k 限制候选词
    • top_p 选择最优词
    • do_sample=True 进行随机采样
  3. 适用于多种 NLP 任务
    • 文本生成(GPT-2, LLaMA)
    • 机器翻译(T5, MarianMT)
    • 摘要生成(BART, T5)
    • 问答任务(T5, BART)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值