将Pytorch训练的.pth模型转为tensorflow的.pb模型的一些坑与解决方法

我用pytorch训练YOLOv4模型,训练时的log想转为tf的.pb以执行后续的工作。在转换过程中踩了不少坑,在此记录。

1. Yolo模型

网上很多转换的方法,都大同小异。但需要注意,yolo训练的log只是参数,没有网络架构,因此需要导入自己的YOLO_body,再把log的参数对应上。

import torch
from nets.yolo4 import YoloBody # 我直接在yolov4的文件夹下新建的文件,所以这里直接导入了


model = YoloBody(3,3) # YoloBody(num_anchors, number_classes) 修改成自己的锚框数和类别数

model.load_state_dict(torch.load('model_data\yolo4_weights.pth')) # 把参数和架构对上

如果你的.pth只是参数,而没有架构,会出现报错如:【TypeError: ‘collections.OrderedDict‘ object is not callable...】【 object has no attribute 'state_dict'】这类的,解决可参考https://blog.csdn.net/xiaoqiaoliushuiCC/article/details/114386432

如果是你自己训练的yolo模型,记得修改yolo.py中的那些路径和变量,和自己的数据对上,不然也会报错。

2.转ONNX

import torch
import torch.nn as nn
import torch.onnx
import onnx
#from onnx_tf.backend import prepare
import argparse
import os

dummy_input = torch.randn(1, 3, 608, 608, device='cpu')

input_names=['input1']
output_names=['output1']

torch.onnx.export(model, dummy_input, "insight.onnx", verbose=True,     
                  input_names=input_names, output_names=output_names)

若报错:Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same 若报错这个,则改dummy...中的device。这里一般是没啥问题。

3.转tensorflow

import onnx
from onnx_tf.backend import prepare
import tensorflow as tf

model = onnx.load('./insight.onnx')
tf_model = prepare(model)
tf_model.export_graph('./insight.pb')

成功的化是一个.pb,而不是文件夹

那么变成文件夹很有可能是版本不对,tensorflow这点还是挺讨厌的。

我转换时候的版本如下:

tensorflow                    2.4.2
tensorflow-addons             0.15.0
tensorflow-estimator          2.4.0
tensorflow-gpu                2.5.0
tensorflow-gpu-estimator      2.2.0

onnx                          1.8.0
onnx-tf                       1.6.0
onnxruntime                   1.10.0

onnx-tf的1.6.0版本要去github上下载,百度一下就好

4.其他报错

期间在安装包等过程中遇到的报错。

1.【AttributeError: module 'tensorflow' has no attribute 'gfile'】这个因为tensorflow版本不一样,解决如下:

tf.compat.v1.GraphDef()   # -> instead of tf.GraphDef()
tf.compat.v2.io.gfile.GFile()   # -> instead of tf.gfile.GFile()

或者最简单的方法就是把tensorflow 2.x降版本,降到1.14(个人感觉最稳的)。但好像onnx需要tensorflow>2.2.0,所以就手动把版本不兼容的代码改了。

2.【ERROR: Could not install packages due to an EnvironmentError: [WinError 5] 拒绝访问】

pip install --user xxxxx 

若不行,升级pip。

若还是不行,把包卸载了重新装。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值