把 PyTorch 模型转换为 ONNX,并进行精简,用于嵌入式模型部署
Pytorch2ONNX
注意:多输出需要用 [out1, out2, out3] 列表进行存储
import os
import torch
import torchvision
from torchstat import stat
from torchsummary import summary
from yolo import YOLO
# Step up one directory and use the resulting working directory as the
# project root; every path below is built from it.
os.chdir("../")
ROOT = os.getcwd()

# File name under which the exported ONNX model is saved.
ONNX_PATH = os.path.join(ROOT, "onnx", "xhh_yolov4_tiny.onnx")

# Weights file, given relative to ROOT.
file_path = (
    "logs/3.2 use_mosaic use_Cosine_lr use_anchors/"
    "Epoch600-Total_Loss1.6992-Val_Loss1.7980.pth"
)
MODEL_PATH = os.path.join(ROOT, file_path)
# [1]获取模型
def getYoloModel():
from nets.yolo4_tiny import YoloBody
# model = YoloBody(len(anchors[0]), len(class_names)).eval()
model = YoloBody(3, 2, xhh_out_list_for_Inference=True).eval()
print('Loading weights into state dict...')
state_dict = torch.load(MODEL_PATH, map_location=