在官网上下载好pytorch版的预训练模型
如果直接使用
model = BertForSequenceClassification.from_pretrained(model_name)
会出现如下报错。
OSError: Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
解决方法:
config = BertConfig.from_json_file(model_name+'/bert_config.json')
config.num_labels = class_num
model = BertForSequenceClassification.from_pretrained(model_name,config=config)
首先使用BertConfig加载模型下的config文件,然后再使用下游任务模型加载。