def export_bert_onnx(model_path='../data_path/models/roberta_hf',
                     onnx_path='bert_model.onnx',
                     opset_version=10):
    """Export a HuggingFace BERT/RoBERTa checkpoint to an ONNX file.

    Args:
        model_path: directory of the pretrained HF model
            (default preserves the original hard-coded path).
        onnx_path: destination ``.onnx`` file (default: ``bert_model.onnx``).
        opset_version: ONNX opset to target (default 10, as before).
    """
    import torch
    from transformers import AutoConfig, AutoModel, AutoTokenizer
    from transformers.onnx.features import FeaturesManager

    bert_model = AutoModel.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Bug fix: the original always used the 'bert' ONNX config even though the
    # default checkpoint is RoBERTa. Use the checkpoint's own model_type when
    # FeaturesManager supports it; fall back to 'bert' otherwise.
    model_type = getattr(config, 'model_type', 'bert')
    if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE:
        model_type = 'bert'
    onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][
        'sequence-classification'](config)
    dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

    input_ids = dummy_inputs['input_ids']
    attention_masks = dummy_inputs['attention_mask']
    # Bug fix: RoBERTa tokenizers do not emit token_type_ids, so the original
    # subscript raised KeyError. Default to zeros so the exported graph always
    # has the same three inputs.
    token_type_ids = dummy_inputs.get('token_type_ids',
                                      torch.zeros_like(input_ids))

    input_names = ["input_ids", "attention_masks", "token_type_ids"]
    output_names = ["output"]
    torch.onnx.export(bert_model,
                      (input_ids, attention_masks, token_type_ids),
                      onnx_path,
                      verbose=True,
                      input_names=input_names,
                      output_names=output_names,
                      opset_version=opset_version)
def export_pb(onnx_path, export_dir=r'./pb/1'):
    """Convert an ONNX model file to a TensorFlow SavedModel (pb).

    Args:
        onnx_path: path to the ``.onnx`` file to convert.
        export_dir: directory the SavedModel is written to. The default
            preserves the original hard-coded ``./pb/1`` destination.
    """
    import onnx
    from onnx_tf.backend import prepare

    print("加载模型:", onnx_path)
    model = onnx.load(onnx_path)
    # Bug fix: the original printed the ONNX file's directory as the export
    # target but then exported to a hard-coded './pb/1'. Report the directory
    # actually written to, and let callers override it.
    print('导出pb模型:', export_dir)
    tf_model = prepare(model)
    tf_model.export_graph(export_dir)
    print("导出pb完成!")
if __name__ == '__main__':
    # Two-step pipeline: step 1 (run once) exports the HF checkpoint to
    # ONNX; step 2 converts the resulting ONNX file to a TensorFlow pb model.
    # export_bert_onnx()
    export_pb('bert_model.onnx')
# Bert model exported to ONNX and pb (TensorFlow SavedModel) formats.
# (Blog footer artifact: latest recommended article published 2024-08-13 08:16:12.)