我用pytorch训练YOLOv4模型,训练时的log想转为tf的.pb以执行后续的工作。在转换过程中踩了不少坑,在此记录。
1. Yolo模型
网上很多转换的方法,都大同小异。但需要注意,yolo训练的log只是参数,没有网络架构,因此需要导入自己的YOLO_body,再把log的参数对应上。
import torch
from nets.yolo4 import YoloBody # 我直接在yolov4的文件夹下新建的文件,所以这里直接导入了
model = YoloBody(3,3) # YoloBody(num_anchors, number_classes) 修改成自己的锚框数和类别数
model.load_state_dict(torch.load('model_data\yolo4_weights.pth')) # 把参数和架构对上
如果你的.pth只是参数,而没有架构,会出现报错如:【TypeError: ‘collections.OrderedDict‘ object is not callable...】【 object has no attribute 'state_dict'】这类的,解决可参考https://blog.csdn.net/xiaoqiaoliushuiCC/article/details/114386432。
如果是你自己训练的yolo模型,记得修改yolo.py中的那些路径和变量,和自己的数据对上,不然也会报错。
2.转ONNX
import torch
import torch.nn as nn
import torch.onnx
import onnx
#from onnx_tf.backend import prepare
import argparse
import os
dummy_input = torch.randn(1, 3, 608, 608, device='cpu')
input_names=['input1']
output_names=['output1']
torch.onnx.export(model, dummy_input, "insight.onnx", verbose=True,
input_names=input_names, output_names=output_names)
若报错:Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same 若报错这个,则改dummy...中的device。这里一般是没啥问题。
3.转tensorflow
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
model = onnx.load('./insight.onnx')
tf_model = prepare(model)
tf_model.export_graph('./insight.pb')
成功的化是一个.pb,而不是文件夹
那么变成文件夹很有可能是版本不对,tensorflow这点还是挺讨厌的。
我转换时候的版本如下:
tensorflow 2.4.2
tensorflow-addons 0.15.0
tensorflow-estimator 2.4.0
tensorflow-gpu 2.5.0
tensorflow-gpu-estimator 2.2.0
onnx 1.8.0
onnx-tf 1.6.0
onnxruntime 1.10.0
onnx-tf的1.6.0版本要去github上下载,百度一下就好
4.其他报错
期间在安装包等过程中遇到的报错。
1.【AttributeError: module 'tensorflow' has no attribute 'gfile'】这个因为tensorflow版本不一样,解决如下:
tf.compat.v1.GraphDef() # -> instead of tf.GraphDef()
tf.compat.v2.io.gfile.GFile() # -> instead of tf.gfile.GFile()
或者最简单的方法就是把tensorflow 2.x降版本,降到1.14(个人感觉最稳的)。但好像onnx需要tensorflow>2.2.0,所以就手动把版本不兼容的代码改了。
2.【ERROR: Could not install packages due to an EnvironmentError: [WinError 5] 拒绝访问】
pip install --user xxxxx
若不行,升级pip。
若还是不行,把包卸载了重新装。