作者|huggingface 编译|VK 来源|Github
加载Google AI或OpenAI预训练权重或PyTorch转储
from_pretrained()
方法
要加载Google AI、OpenAI的预训练模型或PyTorch保存的模型(用torch.save()
保存的BertForPreTraining
实例),PyTorch模型类和tokenizer可以被from_pretrained()
实例化:
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None, from_tf=False, state_dict=None, *input, **kwargs)
其中
BERT_CLASS
要么是用于加载词汇表的tokenizer(BertTokenizer
或OpenAIGPTTokenizer
类),要么是加载八个BERT或三个OpenAI GPT PyTorch模型类之一(用于加载预训练权重):BertModel
,BertForMaskedLM
,BertForNextSentencePrediction
,BertForPreTraining
,BertForSequenceClassification
,BertForTokenClassification
,BertForMultipleChoice
,BertForQuestionAnswering
,OpenAIGPTModel
,OpenAIGPTLMHeadModel
或OpenAIGPTDoubleHeadsModel
PRE_TRAINED_MODEL_NAME_OR_PATH
为:Google AI或OpenAI的预定义的快捷名称列表,其中的模型都是已经训练好的模型:
bert-base-uncased
:12个层,768个隐藏节点,12个heads,110M参数量。bert-large-uncased
:24个层,1024个隐藏节点,16个heads,340M参数量。bert-base-cased
:12个层,768个隐藏节点,12个heads,110M参数量。