使用 transformers 库调用 T5 模型的示例代码:
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Minimal T5 inference demo: English -> German translation with "t5-small".
checkpoint = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# The task is requested through the plain-text prefix at the start of the prompt.
prompt = 'translate English to German: The house is wonderful.'
encoded = tokenizer(prompt, return_tensors='pt')
input_ids = encoded.input_ids

# Generate with default settings and print the first (only) sequence,
# with special tokens stripped from the decoded text.
outputs = model.generate(input_ids)
translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(translation)
注意:如果要处理中文文本,一定要改用 mT5 模型(原版 T5 的词表不支持中文):
import torch
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
# mT5 inference demo: load the multilingual tokenizer and model from a local
# checkpoint directory, tokenize one prompt and run generation on it.
tokenizer = MT5Tokenizer.from_pretrained('/home/xiaoguzai/模型/mt5')
model = MT5ForConditionalGeneration.from_pretrained('/home/xiaoguzai/模型/mt5')

# Tokenize the prompt into a PyTorch tensor of token ids (batch size 1).
# NOTE(review): a plain pretrained mT5 checkpoint is presumably not fine-tuned
# on prefixed tasks, so the "summarize:" prefix only has an effect if this
# local checkpoint was fine-tuned for it -- confirm.
input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids

# The earlier hard-coded `torch.tensor([[1,2,3,4,5]])` override was removed:
# it replaced the tokenized prompt above with arbitrary ids, so generation
# never saw the actual input text.
outputs = model.generate(input_ids)