transformer库generate函数参数测试

第一篇文章想要记录一下自己在玩transformer库中模型"gpt2"进行生成文字过程中调整generate函数参数对生成的文字的影响。

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "I am happy because"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# generate up to 30 tokens
outputs = model.generate(input_ids, do_sample=False, max_length=40,pad_token_id=tokenizer.eos_token_id)
output_text=tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(output_text)
1.do_sample(booloptional, defaults to False)

此时设置do_sample=False,这时使用的是贪心解码方式(greedy decoding)

打印如下

['I am happy because I am a good person. I am a good person. I am a good person. I am a good person. I am a good person. I am a good person.']

可以看出此时gpt2出现了严重的“复读机问题” ,但如果设置do_sample=True,则此时使用多项式采样的方式,生成如下文字:

["I am happy because it has been a challenge for me to make a great job. My life has been incredibly simple and I'm so grateful. I'm a student at a university where even though I"]

同时,多项式采样每一次生成的结果并不相同,而贪心解码一般是一样的。

2. temperature(floatoptional, defaults to 1.0) 

temperature参数是用于控制生成文本的随机性和多样性,其本质是调整了模型输出的logits概率分布。当temperature较高时,会更平均地分配概率给各个token,这导致生成的文本更具随机性和多样性;temperature较低接近0时,会倾向于选择概率最高的token,从而使生成的文本更加确定和集中。

注意,当使用采样方式时才可以使用该参数

当设置temperature=2.0时可以看出gpt2已经开始胡言乱语了

['I am happy because your comments help get my campaign launched at my level of fame so as to take more heat from me for my \'flab policy\' being more "savage":\n\nFirst']

当设置temperature=0.5时打印如下

['I am happy because I am a proud American. I am proud that I am a proud American. I am proud that I am a proud American. I am proud that I am a proud American.']

可以发现此时同样可能产生复读机问题

3.top_p (floatoptional, defaults to 1.0)

top_p与top_k类似,但是是使用累计概率。top_p的值越小,则可采样的词越少,会出现复读机问题。若top_p=0.5:

['I am happy because I have a good partner and we have a good family. We are happy because we have a good partner and we have a good family. We have a good family. We have']

若top_p=0.1:

['I am happy because I am a woman. I am a woman. I am a woman. I am a woman. I am a woman. I am a woman. I am a woman. I']
4.top_k (intoptional, defaults to 50)

该参数用于在生成下一个token时,限制模型只能考虑前k个概率最高的token。

若top_k=100:

['I am happy because there was no reason for us to be here," Karras said of his own life after passing his late father on.\n\nThe family in April shared an arrangement with E']

若top_k=10:

['I am happy because the people here are doing what it takes to win a championship in the NBA," said James, who will play for the Celtics. "I\'m a believer in basketball. I know']

看起来top_k=10时模型输出比top_k=100时的输出更加合理许多。

5.repetition_penalty (floatoptional, defaults to 1.0) 

这是一个重复惩罚的参数,用于缓解复读机现象。

上文观察到当设置top_p=0.1时会出现明显的复读机现象,下面进行测试

当设置top_p=0.1且repetition_penalty=2.0时:

['I am happy because I have a lot of friends who are very good at it.\n"It\'s not like they\'re going to be able, you know? They\'ll just go out and do']

可以看到重复惩罚是显著有效的

6.一些token

pad_token_id :padding token的id

bos_token_id :序列起始token的id

eos_token_id:序列结束token的id

7.其它

因为一些参数不好演示效果,在这里只做简单的记录

no_repeat_ngram_size:如果设置为 int > 0,则该大小的所有 n-gram的token 只能出现一次。

encoder_no_repeat_ngram_size :如果设置为 int > 0,则encoder_input_ids 中出现的该大小的所有n-gram的token都不能出现在decoder_input_ids 中。

bad_words_ids:不允许生成的令牌 ID 列表。 如果想要获取不应出现在生成文本中的单词的标记 ID,可以使用 tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids

force_words_ids:必须生成的令牌 ID 列表,与bad_word_ids刚好相反

浅发一下吧,内容很水,全当学习的记录了,自己也是刚开始入门nlp/大模型,希望和各位大佬多多交流

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小麦要吃麦当劳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值