Pytorch模型转Tensorflow模型的那些事


最近在研究Query2Title模型,学术界上快速实验一般都用pytorch,但是业界部署模型上大多都还是tensorflow模型部署。也可能是自己太懒了,哈哈,pytorch用久了tensorflow有点生疏,懒的去用啦!结果就…,线上部署必须得tensorflow模型,所以就把torch的.pt文件,转成tensorflow的pb文件把,中间遇到一些问题记录下来,仅供借鉴参考!<ps!立下flag,坚持写一些博客>

pytorch模型转tensorflow流程

总体上pytorch转tensorflow遵循2个步骤:

  1. torch模型文件.pt保存成onnx模型文件.onnx
  2. .onnx模型文件转换成tensorflow .pb文件

torch模型文件转onnx文件

这里直接利用torch的onnx库将.pt文件load后导出为.onnx文件即可,基本上没有什么坑,如下:

import tensorflow as tf
import torch
import onnx
from onnx_tf.backend import prepare
import os
import numpy as np


def torch2Onnx():
    """
    pytorch转onnx
    """
    model_pytorch = YourModel()
    model_pytorch.load_state_dict(torch.load('xxx.pt',
                                       map_location='cpu'))
    # 输入placeholder
    dummy_input = torch.randint(0, 10000, (1, 20))
    dummy_output = model_pytorch(dummy_input)
    print(dummy_output.shape)

    # Export to ONNX format
    torch.onnx.export(model_pytorch, 
                      dummy_input, 
                      'model.onnx', 
                      input_names=['inputs'], 
                      output_names=['outputs'])
    

需要注意的是:tensorflow1.x模型是需要先定义计算图的,计算图需要输入输出节点,所以最好指定计算图的输入输出节点名称,以便后续调用tensorflow模型

.onnx文件转tensorflow .pb文件

这一步利用开源工具onnx-tf (github) 将.onnx文件转换成.pb文件。可以直接源码安装,也可以直接 pip安装,转换过程如下:

def onnx2Tensorflow(onnx_model="model.onnx", tf_model="model.tf.pb"):
    """
    onnx转tensorflow
    """
    """
    # ---------------------报错-------------------------
    # Load ONNX model and convert to TensorFlow format
    model_onnx = onnx.load(onnx_model)
    tf_rep = prepare(model_onnx)
    # Export model as .pb file
    tf_rep.export_graph(tf_model)
    """
    os.system("onnx-tf convert -i %s -o %s"%(onnx_model, tf_model))

官方教程是按照注释中的内容做的,但是我自己执行报错,但是命令行方式转换可以成功。

导入计算图,测试样例

转换成功后,我们加载计算图测试一下样例输出是否正确,如下:

def load_pb(path_to_pb):
    """
    加载pb文件
    """
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        # 定义计算图
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

def loadTFModel():
    """
    加载tensorflow模型测试
    """
    graph = load_pb('model.tf.pb')
    with tf.Session(graph=graph) as sess:
        # Show tensor names in graph
        for op in graph.get_operations():
            # 获取节点名称
            print(op.values())
        # 获取输入输出节点
        output_tensor = graph.get_tensor_by_name('div_3:0')
        input_tensor = graph.get_tensor_by_name('inputs:0')

        # dummy_input = np.random.randint(0, 1000, (1, 20), dtype=np.int64)
        query1 = np.array([[4158, 7811, 6653,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]], dtype=np.int64)
        query2 = np.array([[9914, 10859, 6907, 8719, 7098, 
                            8861, 4158, 10785, 6299, 1264, 
                            1612, 10285, 6973, 7811, 0,
                            0, 0, 0, 0, 0]],
                           dtype=np.int64)
        output1 = sess.run(output_tensor, 
                          feed_dict={input_tensor: query1})
        output2 = sess.run(output_tensor,
                          feed_dict={input_tensor: query2})
        
        simi = np.dot(output1[0], output2[0])
        print(simi)

很明显在测试过程中,需要给输入张量喂入输出,获取输出节点数据,这个时候我们在第一步指定的输入、输出节点变量名就派上了用场。input_tensor获取输入节点的张量值就使用了变量"inputs:0",":0"表示该节点的第一个输出,如果该节点只有一个输出就是:0;可以看到,在第一步中我们设定的输出变量名是outputs,但是这里并不是。(个人觉得是有点bug的,所以具体的输出节点名称要从计算图上的节点中去查看。当然如果使用最新版本的onnx-tf这个问题是不存在的,由于我需要的pb文件是tf 1.12版本的,亲测1.13版本以上,没有变量名映射的bug)

注意版本差异和环境

  1. 不同的tensorflow版本和onnx-tf版本需要适配:个人总结tf 1.13以上的版本可以使用最新的onnx-tf 1.7版本;tf 1.12版本可以使用onnx-tf 1.3版本。
  2. 注意模型推理是在cpu还是gpu上:tensorflow模型在cpu上推理和gpu上推理存在差异!torch文件默认的输入数据格式是NxCxHxW,而TensorFlow中的函数在gpu环境上都是支持NxCxHxW格式的。但是部分函数,例如max_pool在cpu 上只支持NxHxHxC的格式。解决办法就是,如果要在cpu上推理,需要禁用gpu: os.environ[“CUDA_VISIBLE_DEVICES”]="-1" ,否则转换会默认按照gpu环境导出;当然你如果没有gpu环境,那就不存在这个问题,自动按照cpu运行环境导出。

各位看客小伙伴,有什么问题或者好的解决办法可以留言评论哦!

  • 14
    点赞
  • 62
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
Pytorch是一个基于Python的深度学习框架,而TensorFlow是由谷歌开发的另一个主要深度学习框架。尽管它们都具有相似的功能和应用领域,但它们的底层结构和语法有一些不同。 要将Pytorch模型换为TensorFlow模型,需要进行以下步骤: 1. 确保数据预处理和模型的加载和保存方法与两个框架兼容。Pytorch使用torchvision库来加载和处理数据,而TensorFlow则使用tf.data.Dataset。可以编写一个通用的数据预处理函数,在换过程中适应两个框架的要求。模型加载和保存方法也有所不同,因此需要检查和调整相应的代码。 2. 对于网络架构的换,可以通过手动编写等效的TensorFlow代码来实现。首先,将Pytorch模型的输入、输出和中间层的形状记录下来。然后,将这些信息用于初始化TensorFlow模型,并按照相同的层次结构和参数进行换。需要注意的是,PytorchTensorFlow的层名称和参数格式可能不同,因此需要进行一些调整。 3. 在模型换过程中,还需要调整损失函数和优化器。PytorchTensorFlow使用不同的损失函数和优化器,因此需要将它们进行等效匹配或手动实现。可以将Pytorch的损失函数换为TensorFlow的等效函数,并使用TensorFlow的优化器进行训练。 4. 进行模型的训练和测试,并根据需要进行微调和优化。在训练和测试过程中,可能需要进行调整以适应TensorFlow框架的要求,例如调整图像的通道顺序或输入的格式。 总的来说,将Pytorch模型换为TensorFlow模型需要一些手动调整和修改,但可以通过适应两个框架的不同要求来实现。需要确保数据预处理、网络架构、损失函数和优化器等方面的兼容性,并在训练和测试过程中进行适当的调整和优化。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值