Today I needed to use BERT from the transformers library for masked-token prediction, so here is the code I ended up with:
import torch
from transformers import BertTokenizer, BertForMaskedLM
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
text = '[CLS] 我 是 [MASK] 国 人 [SEP]'
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Create the segments tensors.
segments_ids = [0] * len(tokenized_text)
# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model.eval()
masked_index = tokenized_text.index('[MASK]')
# Predict all tokens
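with torch.no_grad():
    # Pass the segment ids by keyword: in transformers the second positional argument is attention_mask
    predictions = model(tokens_tensor, token_type_ids=segments_tensors)
# predictions[0] holds the logits with shape (batch, sequence, vocabulary); take the best id at the masked position
predicted_index = torch.argmax(predictions[0][0][masked_index]).item()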
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)
Prediction output:
.....
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
中
Not bad, and a little magical. The transformers version I used:
transformers 3.0.2
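The script above keeps only the single best token via argmax. As a rough illustration of what happens at the masked position, the sketch below ranks several candidates from one row of scores using plain Python. The helper `top_k_tokens`, the four-token vocabulary, and the scores are all made up for illustration and are not real BERT output.

```python
import math


def top_k_tokens(logits_row, vocab, k=3):
    """Return the k most probable (token, probability) pairs for one position."""
    # Numerically stable softmax over the vocabulary scores.
    peak = max(logits_row)
    exps = [math.exp(x - peak) for x in logits_row]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank vocabulary ids by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(vocab[i], probs[i]) for i in ranked[:k]]


# Hypothetical toy vocabulary and made-up scores (assumptions, not model output).
toy_vocab = ["中", "美", "外", "好"]
toy_logits = [
    [0.1, 0.2, 0.3, 0.4],   # position 0
    [5.0, 3.0, 1.0, -2.0],  # position 1, treated here as the [MASK] position
]
masked_index = 1

candidates = top_k_tokens(toy_logits[masked_index], toy_vocab, k=2)
for token, prob in candidates:
    print(f"{token}\t{prob:.3f}")
```

In the real script, the same idea could probably be applied by calling `torch.topk` on `predictions[0][0][masked_index]` and converting the returned ids with `tokenizer.convert_ids_to_tokens`, though I have not checked that against this exact setup.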
References
[1] Predicting missing words in a sentence - natural language processing model. Stack Overflow. https://stackoverflow.com/questions/54978443/predicting-missing-words-in-a-sentence-natural-language-processing-model