# Export a trained BertClassifier checkpoint to TorchScript via tracing, then run
# both the traced module and the eager model on the same tokenized sample so
# their outputs can be compared.
#
# NOTE(review): `text`, `torch`, `BertClassifier` and `BertTokenizer` are
# expected to be defined/imported earlier in the file — not visible here.
import os

# The model is constructed on the CPU and never moved to another device, so load
# the checkpoint straight onto the CPU. Mapping it to CUDA only allocated GPU
# memory that load_state_dict then copied back into the CPU parameters.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints
# (on newer torch versions, consider weights_only=True for plain state dicts).
model = BertClassifier()
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
# Switch submodules such as dropout (if the model has any) to inference
# behaviour before tracing, so the recorded graph matches deployment.
model.eval()

tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
# Pad/truncate to a fixed 128 tokens: a traced graph may bake in shapes seen
# during tracing, so inputs at inference time should use the same length.
# return_tensors='pt' yields int64 tensors with a leading batch dimension of 1,
# the same tensors the previous manual torch.tensor([...]) wrapping produced.
token = tokenizer(
    text,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors='pt',
)
input_ids = token['input_ids']
attention_mask = token['attention_mask']
token_type_ids = token['token_type_ids']

# NOTE(review): positional order assumes BertClassifier.forward(input_ids,
# attention_mask, token_type_ids) — confirm against the model definition.
jit_sample = (input_ids, attention_mask, token_type_ids)

# save() fails if the target directory does not exist yet.
os.makedirs('pthmodes', exist_ok=True)

# Neither tracing nor the comparison runs need gradients; no_grad avoids building
# an autograd graph (and keeping its activations alive) for each forward pass.
with torch.no_grad():
    module = torch.jit.trace(model, jit_sample)
    module.save('pthmodes/model_jit.pt')

    # Both outputs come from the same inputs; they are expected to agree
    # (e.g. check with torch.testing.assert_close) if tracing was faithful.
    jit_predict = module(
        input_ids,
        attention_mask,
        token_type_ids,
    )
    model_predict = model(
        input_ids,
        attention_mask,
        token_type_ids,
    )
# Tags: pytorch, TorchScript
# (Source page footer: latest recommended article published 2024-04-24 19:32:27)