pytorch 模型与tf模型转换

一 bert_model.ckpt转pytoch_model.bin

Transformers库也是也提供了相关代码,这里做个搬运工
convert_bert_original_tf_checkpoint_to_pytorch.py
参考文章:https://zhuanlan.zhihu.com/p/361300189

二 pytoch_model.bin转bert_model.ckpt

convert_pytorch_checkpoint_to_tf2.py

三 torch.jit【将torch文件转换成C++可调用的文件】

1.1 JIT是什么

JIT 是一种概念,全称是 Just In Time Compilation,中文译为「即时编译」,是一种程序优化的方法,一种常见的使用场景是「正则表达式」

TorchScript(PyTorch 的 JIT 实现)
TorchScript是Pytorch模型(继承自nn.Module)的中间表示,可以在像C++这种高性能的环境中运行。

用 JIT 将 Python 模型转换为 TorchScript Module

https://pytorch.org/docs/stable/generated/torch.jit.trace.html

1.2 JIT的好处

模型部署
PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便得调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等…

模型可视化
TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 forward 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式,这两种方式在下文中将详细介绍。

1.3 TorchScript Module 的两种生成方式

  1. 编码(Scripting)
    可以直接使用 TorchScript Language 来定义一个 PyTorch JIT Module,然后用 torch.jit.script 来将他转换成 TorchScript Module 并保存成文件。而 TorchScript Language 本身也是 Python 代码,所以可以直接写在 Python 文件中。
    使用 TorchScript Language 就如同使用 TensorFlow 一样,需要前定义好完整的图。对于 TensorFlow 我们知道不能直接使用 Python 中的 if 等语句来做条件控制,而是需要用 tf.cond,但对于 TorchScript 我们依然能够直接使用 if 和 for 等条件控制语句,所以即使是在静态图上,PyTorch 依然秉承了「易用」的特性。TorchScript Language 是静态类型的 Python 子集,静态类型也是用了 Python 3 的 typing 模块来实现,所以写 TorchScript Language 的体验也跟 Python 一模一样,只是某些 Python 特性无法使用(因为是子集),可以通过 TorchScript Language Reference 来查看和原生 Python 的异同。
    理论上,使用 Scripting 的方式定义的 TorchScript Module 对模型可视化工具非常友好,因为已经提前定义了整个图结构。
  1. 追踪(Tracing)
    使用 TorchScript Module 的更简单的办法是使用 Tracing,Tracing 可以直接将 PyTorch 模型(torch.nn.Module)转换成 TorchScript Module。「追踪」顾名思义,就是需要提供一个「输入」来让模型 forward 一遍,以通过该输入的流转路径,获得图的结构。这种方式对于 forward 逻辑简单的模型来说非常实用,但如果 forward 里面本身夹杂了很多流程控制语句,则可能会有问题,因为同一个输入不可能遍历到所有的逻辑分枝。

此外,还可以混合使用上面两种方式。

import torch
from torch import nn
from torch import quantization
 
class ConvBnReluModel(nn.Module):
    def __init__(self) -> None:
        super(ConvBnReluModel, self).__init__()
        self.conv = nn.Conv2d(3,5,3, bias=False)
        self.bn = nn.BatchNorm2d(5)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
 
m = ConvBnReluModel()
m.eval()
layers = [['conv','bn','relu']]
f = quantization.fuse_modules(m,layers, inplace=True)
 
types_to_quantize = {nn.Conv2d, nn.BatchNorm2d, nn.ReLU}
q = quantization.quantize_dynamic(f, types_to_quantize, dtype=torch.qint8)
 
s = torch.jit.script(q)
torch.jit.save(s, 'quantize_model.pth')
from transformers import (
        BertModel, 
        BertTokenizer, 
        BertConfig, 
        AutoModelForSequenceClassification,
        AutoConfig,
        AutoTokenizer
)
import torch

config = AutoConfig.from_pretrained(
        data_dir + "config.json",
        num_labels=2,
        finetuning_task="",
        cache_dir=None,
        revision="main",
        use_auth_token=None,
)
tokenizer = AutoTokenizer.from_pretrained(
        data_dir,
        cache_dir=None,
        use_fast=True,
        revision="main",
        use_auth_token=None,
)
model = AutoModelForSequenceClassification.from_pretrained(
        data_dir,
        from_tf=False,
        config=config,
        cache_dir=None,
        revision="main",
        use_auth_token=None,
)


# Masking one of the input tokens
sentence1_list = ["query1"]
sentence2_list = ["query2"]
token_args = (sentence1_list, sentence2_list)
result = tokenizer(*token_args, padding="max_length", max_length=64, truncation=True)
masks = result["attention_mask"]
input_ids = result["input_ids"]
segments_ids = result["token_type_ids"]
#masks = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
#input_ids = [[101, 2458, 1912, 2542, 4384, 102, 2193, 5661, 1068, 7308, 2193, 5661, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
#segments_ids = [[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]


# Creating a dummy input
tokens_tensor = torch.tensor(input_ids)
atterntion_tensor = torch.tensor(masks)
token_type_tensors = torch.tensor(segments_ids)
dummy_input = [tokens_tensor, atterntion_tensor, token_type_tensors]
# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.

# Instantiating the model
model.eval()

# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
#model = BertModel.from_pretrained("/nfs/volume-1070-2/yushu/D1_smi/model_out/d1_lcqmc_15_model_64_200_3e", torchscript=True)



# Creating the trace
print("model:")
print(model)
#traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors], strict=False)
traced_model = torch.jit.trace(model, [tokens_tensor, atterntion_tensor, token_type_tensors])
torch.jit.save(traced_model, "traced_bert_classify.pt")

#loaded_model = torch.jit.load("./pt_model/D1_similar_model/D1_similar_model.pt")
loaded_model = torch.jit.load("traced_bert_classify.pt")
loaded_model.eval()

print(loaded_model)
#all_encoder_layers, pooled_output = loaded_model(*dummy_input)

pooled_output = loaded_model(*dummy_input)
#print(all_encoder_layers)
print("***********")
print(pooled_output)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值