import transformers
print(transformers.__version__)
from transformers import BertTokenizer, BartForConditionalGeneration
bart_path="/home/xhsun/NLP/huggingfaceModels/Chinese/chinese-bart-base/"
tokenizer = BertTokenizer.from_pretrained(bart_path)
model = BartForConditionalGeneration.from_pretrained(bart_path)
My transformers version is 4.15.0.
in_sentence='华盛顿是[MASK]的首都'  # "Washington is the capital of [MASK]"
print(model.config.is_encoder_decoder)
input_ids = tokenizer.encode(in_sentence, return_tensors='pt')
num_beams=4
model.config.num_return_sequences=num_beams
pred_ids = model.generate(input_ids, num_beams=num_beams, max_length=20,return_dict_in_generate=True,output_scores=True)
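- By default num_beams is 1, i.e. no beam search; here the beam width is set to 4.
- Setting model.config.num_return_sequences=num_beams makes generate return all num_beams paths instead of only the highest-scoring one (the default). num_return_sequences must not exceed num_beams.
- return_dict_in_generate=True and output_scores=True make the model return the paths together with their scores. Both parameters default to False.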
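To make the three settings above concrete, here is a minimal pure-Python sketch of beam search over a hypothetical toy next-token distribution. The names toy_step, EOS and beam_search are illustrative, not transformers APIs, and the real implementation differs in many details. The sketch keeps num_beams live paths per step, returns the num_return_sequences best finished paths sorted by score, and scores each path as its summed log-probability divided by length ** length_penalty, a simplified imitation of how generate scores beam hypotheses.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = "</s>"  # hypothetical end-of-sequence token
Path = Tuple[str, ...]


def toy_step(prefix: Path) -> Dict[str, float]:
    """Hypothetical toy model: next-token log-probabilities given a prefix."""
    if len(prefix) >= 2:
        return {EOS: 0.0}  # force every path to end after two tokens
    return {"a": math.log(0.6), "b": math.log(0.3), EOS: math.log(0.1)}


def beam_search(step: Callable[[Path], Dict[str, float]], num_beams: int,
                num_return_sequences: int, max_length: int,
                length_penalty: float = 1.0) -> List[Tuple[Path, float]]:
    """Simplified beam search; returns (path, score) pairs, best first."""
    if num_return_sequences > num_beams:
        raise ValueError("num_return_sequences must not exceed num_beams")
    beams: List[Tuple[Path, float]] = [((), 0.0)]  # (prefix, summed log-prob)
    finished: List[Tuple[Path, float]] = []
    for _ in range(max_length):
        # Expand every live beam by every possible next token.
        candidates = [(prefix + (tok,), logp + tok_logp)
                      for prefix, logp in beams
                      for tok, tok_logp in step(prefix).items()]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for path, logp in candidates:
            if path[-1] == EOS:
                # Length-normalized score: sum of log-probs / length ** length_penalty.
                finished.append((path, logp / len(path) ** length_penalty))
            else:
                beams.append((path, logp))
            if len(beams) == num_beams:  # keep only the num_beams best live paths
                break
        if not beams:
            break
    # Paths still alive when max_length is reached are scored the same way.
    finished += [(p, lp / len(p) ** length_penalty) for p, lp in beams]
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:num_return_sequences]


paths = beam_search(toy_step, num_beams=2, num_return_sequences=2, max_length=5)
for path, score in paths:
    print("Path:", path, "Path score:", round(score, 4))
```

With num_return_sequences=1 the same call would return only the top path, which mirrors the default behaviour described in the second bullet.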
print(pred_ids.keys())
print(pred_ids['sequences'])
print(pred_ids['sequences_scores'])
print(in_sentence)
for output_path_ids, path_score in zip(pred_ids['sequences'], pred_ids['sequences_scores']):
    print("Path ids: ", output_path_ids.tolist(), 'Path score: ', path_score.item())
    print(tokenizer.decode(output_path_ids, skip_special_tokens=True))
    print('-------------------------------')
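The num_beams returned paths are already sorted by score, best first. The score is in log space: for beam search, sequences_scores is the sum of the path's token log-probabilities divided by the sequence length raised to length_penalty (1.0 by default). Exponentiating it (base e) therefore gives a length-normalized, per-token probability rather than the probability of the whole sequence.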
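As a worked check of the score-to-probability relationship, the short sketch below uses made-up per-token probabilities (not output from the model above) and assumes the score is computed as summed log-probability divided by length ** length_penalty with length_penalty = 1.0. Under that assumption, exp(score) is the geometric mean of the token probabilities, and undoing the normalization recovers the probability of the full sequence.

```python
import math

# Hypothetical per-token probabilities of one generated path (illustrative only).
token_probs = [0.9, 0.5, 0.8, 0.6]
length_penalty = 1.0  # assumed default; a model config may override it

sum_logprobs = sum(math.log(p) for p in token_probs)
# Assumption: length-normalized score, as used when ranking beam hypotheses.
sequence_score = sum_logprobs / len(token_probs) ** length_penalty

# exp(score): geometric mean of the per-token probabilities.
per_token_prob = math.exp(sequence_score)
# Multiplying back by length ** length_penalty undoes the normalization.
sequence_prob = math.exp(sequence_score * len(token_probs) ** length_penalty)

print(f"score={sequence_score:.4f}")
print(f"exp(score)={per_token_prob:.4f}")
print(f"sequence probability={sequence_prob:.4f}")
```

In the real library the length used for normalization is that of the generated token sequence, so special tokens affect it as well.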
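All PyTorch generative models in transformers share one generate method, implemented in generation_utils.py; apart from a few special cases such as RAG, which overrides it, the individual model files do not define their own generate.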
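The GenerationMixin class contains all the functions needed for generation, and it is mixed into PreTrainedModel, so every class that inherits from PreTrainedModel gets generate and the other generation methods.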
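The mixin arrangement can be illustrated with a few toy classes. GenerationMixinSketch, PreTrainedModelSketch and ToyLM are hypothetical names that only mimic the structure: the real PreTrainedModel also inherits from torch.nn.Module and other mixins, and the real generate is far more elaborate. The point is that generate is defined once in the mixin and reaches every subclass through inheritance, while each model only supplies its forward pass.

```python
class GenerationMixinSketch:
    """Hypothetical stand-in for GenerationMixin: implements generate()
    on top of a forward() that the concrete model must provide."""

    def generate(self, input_ids, max_length=5):
        ids = list(input_ids)
        while len(ids) < max_length:
            logits = self.forward(ids)
            # Greedy pick: append the index of the highest logit.
            ids.append(max(range(len(logits)), key=logits.__getitem__))
        return ids


class PreTrainedModelSketch(GenerationMixinSketch):
    """Hypothetical base class: mixing the generation mixin in here means
    every subclass inherits generate() without defining it."""


class ToyLM(PreTrainedModelSketch):
    """Hypothetical concrete model: only defines forward()."""
    vocab_size = 3

    def forward(self, ids):
        # Give the highest logit to (last token + 1) mod vocab_size.
        nxt = (ids[-1] + 1) % self.vocab_size
        return [1.0 if t == nxt else 0.0 for t in range(self.vocab_size)]


model = ToyLM()
print(model.generate([0], max_length=5))
print("generate defined on ToyLM itself:", "generate" in ToyLM.__dict__)
```

Putting the generation logic in a mixin keeps decoding strategies in one place, so a new model only needs a forward pass to support them.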