ONNX Models
I. torch.onnx.export() in detail
1. torch.onnx.export()
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=False, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None)
(This is the signature from early PyTorch 1.x releases. Several of these arguments, such as aten, export_raw_ir and example_outputs, have since been deprecated or removed.)
2. Purpose:
Exports a PyTorch model (for example, one restored from a .pth file) to an ONNX file.
3. Parameters
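model (torch.nn.Module): the model to export, e.g. one restored from a .pth file;
args (tuple of arguments): example input(s) for the model. The exporter runs the model on them, so their shapes become the input shapes recorded in the ONNX graph unless dynamic_axes is used;
f: the path (or a file-like object) the .onnx file is written to;
export_params (bool, default True): if True, the trained parameters are stored in the ONNX file. Set it to False to export an untrained model; its parameters then become inputs of the exported graph;
verbose (bool, default False): if True, prints a debug description of the traced graph;
training (bool, default False): export the model in training mode. ONNX export is currently aimed at inference, so this normally does not need to be True (newer releases take a torch.onnx.TrainingMode value instead of a bool);
input_names (list of strings, default empty list): names assigned, in order, to the inputs of the ONNX graph;
output_names (list of strings, default empty list): names assigned, in order, to the outputs of the ONNX graph;
opset_version: the ONNX operator-set version to target. It is 9 by default in the releases this signature comes from, and newer releases use a higher default;
dynamic_axes: marks axes whose size may change at run time, keyed by the names given in input_names/output_names. For example, {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} makes dimension 0 of both the input and the output a symbolic batch_size.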
II. Ways to save a .pth file
torch.save(model,'save_path')
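torch.save(model, path) saves both the model's parameters and its structure to path. It pickles the whole model object, which refers to its class by module path. Loading can therefore fail when package versions differ. The code files that define the model must also be importable from the same path at load time, otherwise torch.load cannot find the model's class.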
torch.save(model.state_dict(),model_path)
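Recommendation: save with state_dict(), i.e. torch.save(model.state_dict(), path). This stores a dictionary of tensors keyed by parameter name, which can be loaded directly into a freshly built model with load_state_dict().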
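Here is a minimal sketch of the state_dict round trip. It assumes a hypothetical build_model() in place of your own network class, and an in-memory buffer stands in for the file path.

```python
import io

import torch
from torch import nn


def build_model():
    # Hypothetical architecture standing in for your own Net class.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def save_and_reload_state_dict():
    original = build_model()
    buf = io.BytesIO()                      # in-memory stand-in for model_path
    torch.save(original.state_dict(), buf)  # stores only tensors, keyed by parameter name
    buf.seek(0)

    restored = build_model()                # the structure must be rebuilt from code...
    restored.load_state_dict(torch.load(buf, map_location="cpu"))  # ...then the weights are copied in
    return original.eval(), restored.eval()


if __name__ == "__main__":
    a, b = save_and_reload_state_dict()
    print(list(b.state_dict().keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
```

The file holds no class information, only names and tensors. This is why loading does not depend on where the model's code lives, and also why that code must be available to rebuild the model.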
III. Converting .pth to ONNX
1. Model saved with torch.save(model, 'save_path')
x = torch.randn(1, 3, 224, 224, device=device)
Dummy input data, layout [batch, channel, height, width].
model.eval()
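model.eval() switches the model to inference mode. BatchNorm layers then use the running mean and variance learned during training instead of the statistics of the current batch, and Dropout is disabled. Without it, a small test batch size can easily distort the BN outputs.
Note: always include this call; from experience, leaving it out can change the outputs of the ONNX model.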
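The effect of model.eval() can be checked directly on the two layer types involved. This sketch is not part of the export itself, and its sizes and values are arbitrary. It shows Dropout acting only in train mode, and BatchNorm using stored running statistics in eval mode but per-batch statistics in train mode.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Dropout: random zeros in train mode, identity in eval mode.
drop = nn.Dropout(p=0.5)
ones = torch.ones(1, 8)
drop.train()
drop_train = drop(ones)  # each entry is 0.0 or 1 / (1 - p) = 2.0
drop.eval()
drop_eval = drop(ones)   # returned unchanged

# BatchNorm: eval mode uses running statistics, train mode uses the current batch.
bn = nn.BatchNorm1d(8)
batch = torch.cat([torch.full((1, 8), 1.0), torch.full((1, 8), 3.0)])
bn.eval()                # run eval first: a train-mode forward pass updates the running stats
bn_eval = bn(batch)      # initial running_mean=0, running_var=1, so output is about the input
bn.train()
bn_train = bn(batch)     # batch mean 2 and variance 1, so rows become about -1 and +1

print(drop_train, drop_eval, bn_eval, bn_train, sep="\n")
```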
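import torch

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
# Unpickling a whole model requires the module that defines its class to be importable here.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True; pass weights_only=False for a trusted full-model file.
model = torch.load('***.pth', map_location=device)
model.eval()  # inference behaviour for BN/Dropout before export
input_names = ['input']
output_names = ['output']
x = torch.randn(1, 3, 224, 224, device=device)  # dummy input: [batch, channel, height, width]
torch.onnx.export(model, x, '***.onnx', input_names=input_names, output_names=output_names, verbose=True)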
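2. Model saved with torch.save(model.state_dict(), model_path)
This method also needs the file that defines the network structure (here model.py, which provides the class Net).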
import torch
import torch.onnx
from model import Net  # your own file that defines the network architecture

# Build the network (pass whatever constructor arguments your Net needs)
model = Net()
# Load the weights
model_path = '***.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_statedict = torch.load(model_path, map_location=device)
model.load_state_dict(model_statedict)
model.to(device)
model.eval()
input_data = torch.randn(1, 3, 224, 224, device=device)
# Convert to ONNX
input_names = ['input']
output_names = ['output']
torch.onnx.export(model, input_data, '***.onnx', opset_version=9, verbose=True, input_names=input_names, output_names=output_names)
IV. Downloading and saving resnet18, converting it to ONNX, and exporting the weights as a .wts file
1. Download resnet18 and save it to the 'resnet18.pth' file:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision

def main():
    print('cuda device count: ', torch.cuda.device_count())
    # torchvision >= 0.13 prefers weights=torchvision.models.ResNet18_Weights.DEFAULT over pretrained=True
    net = torchvision.models.resnet18(pretrained=True)
    #net.fc = nn.Linear(512, 2)
    net = net.to('cuda:0')
    net.eval()
    print(net)
    tmp = torch.ones(2, 3, 224, 224).to('cuda:0')
    out = net(tmp)
    print('resnet18 out:', out.shape)
    torch.save(net, "resnet18.pth")  # saves structure + weights (a pickled CUDA model)

if __name__ == '__main__':
    main()
This 'resnet18.pth' file contains both the model structure and the weights.
2. Load the .pth file and convert it to ONNX format:
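import torch

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a model saved on cuda:0 load on a CPU-only machine
    # (on PyTorch >= 2.6 also pass weights_only=False, since this file holds a full pickled model)
    model = torch.load('resnet18.pth', map_location=device)
    model.eval()  # inference mode, so BN uses its running statistics in the exported graph
    inputs = torch.randn(1, 3, 224, 224).to(device)
    # default eval-mode export; training=2 (TrainingMode.TRAINING) would keep training behaviour
    torch.onnx.export(model, inputs, 'resnet18_trtpose.onnx')

if __name__ == '__main__':
    main()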
3. Load the .pth file and write the model's weights to a .wts file:
import torch
import torchvision
import struct
from torchsummary import summary  # third-party package: pip install torchsummary

def main():
    print('cuda device count: ', torch.cuda.device_count())
    net = torch.load('resnet18.pth')  # full pickled model; on PyTorch >= 2.6 add weights_only=False
    net = net.to('cuda:0')
    net.eval()
    print('model: ', net)
    #print('state dict: ', net.state_dict().keys())
    tmp = torch.ones(1, 3, 224, 224).to('cuda:0')
    print('input: ', tmp)
    out = net(tmp)
    print('output:', out)
    summary(net, (3, 224, 224))
    # .wts layout: first line = number of tensors, then one line per tensor:
    # "<name> <element count> <hex> <hex> ...", each value a big-endian float32 in hex
    with open("resnet18.wts", 'w') as f:
        f.write("{}\n".format(len(net.state_dict().keys())))
        for k, v in net.state_dict().items():
            print('key: ', k)
            print('value: ', v.shape)
            vr = v.reshape(-1).cpu().numpy()
            f.write("{} {}".format(k, len(vr)))
            for vv in vr:
                f.write(" ")
                f.write(struct.pack(">f", float(vv)).hex())
            f.write("\n")

if __name__ == '__main__':
    main()
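The .wts file written above is plain text. The first line holds the number of tensors. Each following line holds a tensor name, its element count, and then every value as a big-endian float32 in hex; float(vv) also turns integer buffers such as BatchNorm's num_batches_tracked into floats. The sketch below is a pure-Python writer/reader pair for that layout, which can be handy for sanity-checking a file before using it elsewhere. write_wts and read_wts are hypothetical names, not part of any library.

```python
import struct


def write_wts(path, tensors):
    """Write {name: list of floats} in the same text layout as the script above."""
    with open(path, "w") as f:
        f.write(f"{len(tensors)}\n")
        for name, values in tensors.items():
            f.write(f"{name} {len(values)}")
            for v in values:
                f.write(" " + struct.pack(">f", float(v)).hex())
            f.write("\n")


def read_wts(path):
    """Parse a .wts file back into {name: list of floats}."""
    weights = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            name, n, *hex_values = f.readline().split()
            if len(hex_values) != int(n):
                raise ValueError(f"{name}: expected {n} values, found {len(hex_values)}")
            weights[name] = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_values]
    return weights
```

For example, 1.0 is stored as 3f800000. Values that float32 cannot represent exactly come back rounded to the nearest float32.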