Hugging Face model.generate
方法
model.generate
是 transformers
库中的 文本生成(Text Generation) 方法,适用于 自回归模型(如 GPT-2、T5、BART、LLAMA),用于 生成文本、摘要、翻译、问答等。
1. 适用于哪些模型?
generate
适用于 基于 Transformer 生成文本的模型,例如:
- GPT-2 (
AutoModelForCausalLM
) → 文本生成 - T5 / BART (
AutoModelForSeq2SeqLM
) → 翻译、摘要、填空 - LLaMA / Mistral / Falcon (
AutoModelForCausalLM
) → 长文本生成 - ChatGLM / Bloom / GPT-J (
AutoModelForCausalLM
) → 对话任务
不适用于:
- BERT (
AutoModelForMaskedLM
) - RoBERTa (
AutoModelForMaskedLM
)
2. generate
的基本用法
2.1. 加载预训练模型
from transformers import AutoModelForCausalLM, AutoTokenizer
# 选择 GPT-2 模型
model_name = "gpt2"
# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
2.2. 生成文本
# 输入文本
input_text = "Hugging Face is"
# 编码输入
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
# 生成文本
output_ids = model.generate(input_ids, max_length=50)
# 解码文本
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)
示例输出
Hugging Face is a company based in New York that specializes in AI and NLP...
3. generate
方法的关键参数
参数 | 作用 | 默认值 |
---|---|---|
max_length | 生成文本的最大长度 | 20 |
min_length | 生成文本的最小长度 | 10 |
num_return_sequences | 生成多个结果 | 1 |
temperature | 控制生成文本的随机性(越高越随机) | 1.0 |
top_k | 选择概率最高的 k 个词 | 50 |
top_p | 只选择累计概率超过 p 的词(Nucleus Sampling) | 0.95 |
do_sample | 是否启用随机采样 | False |
repetition_penalty | 惩罚重复生成的文本 | 1.0 |
early_stopping | 是否提前停止 | False |
4. generate
方法的高级用法
4.1. 生成多个文本
output_ids = model.generate(input_ids, max_length=50, num_return_sequences=3)
output_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
for i, text in enumerate(output_texts):
print(f"Generated {i+1}: {text}")
示例输出
Generated 1: Hugging Face is a leading AI company...
Generated 2: Hugging Face is developing new NLP tools...
Generated 3: Hugging Face is revolutionizing deep learning...
4.2. 控制文本的创造性
- 提高
temperature
(0.7 - 1.5)让文本更随机
model.generate(input_ids, max_length=50, temperature=1.2)
- 设置
top_k
让文本更保守
model.generate(input_ids, max_length=50, top_k=10)
- 设置
top_p
(Nucleus Sampling)让文本更加自然
model.generate(input_ids, max_length=50, top_p=0.9)
- 启用
do_sample=True
进行随机采样
model.generate(input_ids, max_length=50, do_sample=True, temperature=0.9)
4.3. 避免重复生成
如果模型重复生成相同的短语,可以使用 repetition_penalty
(1.2-2.0)
model.generate(input_ids, max_length=50, repetition_penalty=1.5)
4.4. 使用 bad_words_ids
避免生成特定单词
bad_words = ["stupid", "bad"]
bad_word_ids = tokenizer(bad_words, add_special_tokens=False).input_ids
model.generate(input_ids, max_length=50, bad_words_ids=bad_word_ids)
4.5. 控制文本的长度
- 限制最小/最大长度
model.generate(input_ids, min_length=30, max_length=100)
- 使用
early_stopping
提前终止
model.generate(input_ids, max_length=50, early_stopping=True)
5. generate
在不同任务中的应用
5.1. 文本生成(GPT-2, LLaMA, Bloom)
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_text = "Artificial intelligence is"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=50, do_sample=True, temperature=0.9)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)
5.2. 机器翻译(T5, MarianMT)
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
input_text = "Hugging Face is a great company"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=50)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text) # "Hugging Face est une grande entreprise"
5.3. 摘要生成(BART, T5)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
input_text = "Hugging Face is a company based in New York that specializes in NLP. It provides tools and libraries for building AI applications."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=30)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text) # "Hugging Face is an NLP company providing AI tools."
5.4. 问答任务(T5, BART)
model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-small-qa")
input_text = "question: What is Hugging Face? context: Hugging Face is an AI company specializing in NLP."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=20)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text) # "An AI company specializing in NLP"
6. 总结
generate
是 Hugging Face 提供的文本生成方法,适用于 GPT、T5、BART 等模型。- 支持多种生成策略:
temperature
控制创造性top_k
限制候选词top_p
选择最优词do_sample=True
进行随机采样
- 适用于多种 NLP 任务:
- 文本生成(GPT-2, LLaMA)
- 机器翻译(T5, MarianMT)
- 摘要生成(BART, T5)
- 问答任务(T5, BART)