pytorch加载bert权重与转换成onnx

3 篇文章 0 订阅

由于bert是google创造的模型,所以大部分都是用tensorflow编写。自从有了transformer库,pytorch版本的模型加载也简单了许多。
权重文件,如图所示:
请添加图片描述
config.json是bert的配置,包括hidden_size,drop此类超参,如下所示:

{
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": 21128
}

bin则是计算图和权重构成的2进制文件。

import os
import tempfile
import numpy as np
from onnxruntime import InferenceSession
import torch
from torch import nn
from transformers import BertPreTrainedModel, BertModel, BertForSequenceClassification

torch.set_grad_enabled(False)
class bert_model(BertPreTrainedModel):
    def __init__(self, config):
        super(bert_model, self).__init__(config)
        self.bert = BertForSequenceClassification(config)
        
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        return bert_output.logits


def export_to_onnx(task, model_dir, output_model_name):
    if task == 1:
        model = bert_model.from_pretrained(model_dir, num_labels=2)
        dummy_input = {
            "input_ids": torch.tensor([[101, 2769, 1372, 2682, 2127, 102, 0]]),
            "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 0]]),
            "token_type_ids": torch.tensor([[0, 0, 0, 0, 0, 0, 0]]),
        }
        dynamic_axes = {
            'input_ids': [0, 1],
            'attention_mask': [0, 1],
            'token_type_ids': [0, 1],
        }
        output_names = ['start_logits', 'end_logits']
    
    with tempfile.NamedTemporaryFile() as fp:
        torch.onnx.export(model,
                          args=tuple(dummy_input.values()),
                          f=fp,
                          input_names=list(dummy_input),
                          output_names=output_names,
                          dynamic_axes=dynamic_axes,
                          opset_version=10)
        sess = InferenceSession(fp.name)
        model.eval()
        if task == 1:
            old_start_logits, old_end_logits = model(**dummy_input.copy())
            new_start_logits, new_end_logits = sess.run(
                output_names=output_names,
                input_feed={key: value.numpy() for key, value in dummy_input.items()})
            np.testing.assert_almost_equal(old_start_logits.numpy(), new_start_logits, 5)
            np.testing.assert_almost_equal(old_end_logits.numpy(), new_end_logits, 5)

环境配置列表:
torch == 1.8.1
transformers == 4.6.1
onnxruntime == 1.8.0
加载只需要把bin文件与json合成一个文件夹,加载文件夹路径即可。

转换成onnx,由于输出有三个embeddings,torch.onnx.export中args使用tuple打包张量(tuple of arguments),input_names按顺序分配名称到图中的输入节点(list of strings)。

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
要将微软用PyTorch实现的GLIP模型转换为ONNX模型,并且要求ONNX模型能够直接加载使用,可以按照以下步骤进行操作: 1. 首先,加载微软用PyTorch实现的GLIP模型。可以根据具体的模型代码进行加载,这里以示例模型为例。 ```python import torch import torchvision.models as models # 加载微软GLIP模型 model = models.glip_model() ``` 2. 然后,定义GLIP模型的输入维度。根据GLIP模型的输入要求进行定义,这里以示例输入维度为(1, 3, 224, 224),表示一张三通道、分辨率为224x224的彩色图像。 ```python dummy_input = torch.randn(1, 3, 224, 224) ``` 3. 接下来,使用`torch.onnx.export()`函数将GLIP模型转换为ONNX格式的模型。在转换过程中,可以选择是否加载模型的权重,这里选择不加载权重。 ```python torch.onnx.export(model, dummy_input, "glip_model.onnx", do_constant_folding=False) ``` 4. 完成上述步骤后,将会生成一个名为"glip_model.onnx"的ONNX模型文件,可以直接加载和使用该模型。 综上所述,将微软用PyTorch实现的GLIP模型转换为ONNX模型并能够直接加载使用的步骤如下所示: ```python import torch import torchvision.models as models # 加载微软GLIP模型 model = models.glip_model() # 定义GLIP模型的输入维度 dummy_input = torch.randn(1, 3, 224, 224) # 将模型转换为ONNX格式但不加载权重 torch.onnx.export(model, dummy_input, "glip_model.onnx", do_constant_folding=False) ``` 通过以上步骤,您可以将微软用PyTorch实现的GLIP模型成功转换为ONNX模型,并能够直接加载和使用该模型。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [模型转换 PyTorchONNX 入门](https://blog.csdn.net/qq_41204464/article/details/129073729)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值