由于BERT是Google提出的模型,其官方实现和预训练权重大多基于TensorFlow。自从有了Hugging Face的transformers库,加载PyTorch版本的模型也简单了许多。
权重文件目录如下图所示(包含config.json配置文件和.bin权重文件):
config.json是BERT的配置文件,包含hidden_size、dropout概率等超参数,如下所示:
{
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"type_vocab_size": 2,
"vocab_size": 21128
}
.bin文件则是以二进制形式保存的模型权重(PyTorch的state_dict),并不包含计算图;计算图由transformers中的模型代码根据config.json构建。
import os
import tempfile
import numpy as np
from onnxruntime import InferenceSession
import torch
from torch import nn
from transformers import BertPreTrainedModel, BertModel, BertForSequenceClassification
# Import-time side effect: turn off autograd for the importing thread. This
# script only runs inference and ONNX export, and with gradients disabled the
# model outputs can be converted with `.numpy()` (a tensor that requires grad
# cannot be).
torch.set_grad_enabled(False)
class bert_model(BertPreTrainedModel):
    """BERT sentence classifier that returns only the logits tensor.

    The submodules are laid out exactly like
    ``BertForSequenceClassification`` (``bert`` -> ``BertModel``,
    ``dropout``, ``classifier``). Their parameter names are therefore
    ``bert.*`` and ``classifier.*``, so ``from_pretrained`` can load standard
    BERT / BertForSequenceClassification checkpoints into this class.

    Wrapping a whole ``BertForSequenceClassification`` in ``self.bert`` would
    produce ``bert.bert.*`` names instead, and no checkpoint weights would be
    loaded.

    Returning a bare tensor, rather than a ``ModelOutput``, keeps the traced
    ONNX graph to a single named output.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialise the parameters not covered by a checkpoint (e.g. a fresh
        # classifier head) using the config's initializer_range.
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: Token id tensor passed to ``BertModel``.
            attention_mask: Optional mask passed to ``BertModel``.
            token_type_ids: Optional segment ids passed to ``BertModel``.

        Returns:
            The ``classifier`` output: one score per label
            (``config.num_labels``) for each example.
        """
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(bert_output.pooler_output)
        return self.classifier(pooled)
def export_to_onnx(task, model_dir, output_model_name, *, num_labels=2, opset_version=10):
    """Export a BERT sequence classifier to ONNX and verify it against PyTorch.

    Loads ``bert_model`` from ``model_dir`` and traces it with a fixed dummy
    batch. The ONNX graph is written to ``output_model_name``, and the function
    then checks that onnxruntime reproduces the PyTorch logits to 5 decimal
    places.

    Args:
        task: Task selector. Only ``1`` (sequence classification) is
            implemented.
        model_dir: Directory accepted by ``from_pretrained`` (config.json
            plus the .bin weight file).
        output_model_name: File path the ONNX model is written to.
        num_labels: Number of classification labels (default 2).
        opset_version: ONNX opset used for the export (default 10).

    Raises:
        ValueError: If ``task`` is not a supported task id.
        AssertionError: If the ONNX logits differ from the PyTorch logits.
    """
    if task != 1:
        # Only task 1 defines a model; reject anything else up front instead
        # of failing later on an unbound name.
        raise ValueError(f"unsupported task {task!r}; only task 1 is implemented")

    model = bert_model.from_pretrained(model_dir, num_labels=num_labels)
    # Switch dropout to inference mode before tracing, so the exported graph
    # and the PyTorch reference run are deterministic.
    model.eval()

    dummy_input = {
        "input_ids": torch.tensor([[101, 2769, 1372, 2682, 2127, 102, 0]]),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 0]]),
        "token_type_ids": torch.tensor([[0, 0, 0, 0, 0, 0, 0]]),
    }
    # Leave batch size (axis 0) and sequence length (axis 1) symbolic on every
    # input, and the batch axis symbolic on the output.
    dynamic_axes = {name: {0: "batch", 1: "sequence"} for name in dummy_input}
    dynamic_axes["logits"] = {0: "batch"}
    # bert_model.forward returns a single logits tensor, so there is exactly
    # one output to name.
    output_names = ["logits"]

    # Write straight to the requested path. This avoids exporting into an
    # unflushed temporary file and then reopening it by name, and the result
    # is actually kept.
    torch.onnx.export(
        model,
        args=tuple(dummy_input.values()),
        f=output_model_name,
        input_names=list(dummy_input),
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
    )

    sess = InferenceSession(output_model_name)
    with torch.no_grad():
        torch_logits = model(**dummy_input)
    (onnx_logits,) = sess.run(
        output_names=output_names,
        input_feed={name: tensor.numpy() for name, tensor in dummy_input.items()},
    )
    np.testing.assert_almost_equal(torch_logits.numpy(), onnx_logits, 5)
环境配置列表:
torch == 1.8.1
transformers == 4.6.1
onnxruntime == 1.8.0
加载时只需把.bin权重文件与config.json放在同一个文件夹中,再把该文件夹路径传给from_pretrained即可。
转换为ONNX时,由于模型有三个输入张量(input_ids、attention_mask、token_type_ids),torch.onnx.export的args参数需用tuple按顺序打包这些张量(tuple of arguments),input_names则按相同顺序为计算图中的输入节点命名(list of strings)。