1. 简略版：Python 转换代码（PyTorch → ONNX）：
import torch
import torchvision
# Export AlexNet to ONNX by tracing: torch.onnx.export runs the model once on
# dummy_input and records the executed operators as an ONNX graph.
# Trace on CUDA when it is available, otherwise fall back to the CPU, so the
# script also runs on machines without a GPU (the original hard-coded 'cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The dummy input fixes the input shape recorded in the exported graph
# (no dynamic_axes are given, so the batch size of 10 is baked in).
dummy_input = torch.randn(10, 3, 224, 224, device=device)
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=torchvision.models.AlexNet_Weights.DEFAULT`; kept here for older versions.
model = torchvision.models.alexnet(pretrained=True).to(device)
model.eval()  # trace inference-mode behaviour (e.g. dropout disabled)
input_names = ['inputs']
output_names = ['outputs']
# The trained weights are embedded in the file, so it is large (~244 MB).
torch.onnx.export(model, dummy_input, f='alexnet.onnx', verbose=True,
                  input_names=input_names, output_names=output_names,
                  opset_version=10)
生成的 ONNX 结构：
2. 完整版：
import numpy as np
import onnx
import onnxruntime
from torch import nn
import torch.nn.init as init
import torch.utils.model_zoo as model_zoo
import torch.onnx
# Super-resolution model: efficient sub-pixel convolutional network
# (Shi et al., 2016), https://arxiv.org/abs/1609.05158
"""
save pytorch model to torch.onnx model
and verify the model by using onnx.
"""
class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network for single-channel super-resolution.

    Four convolutions (a ReLU after each of the first three) produce
    ``upscale_factor ** 2`` feature maps. ``nn.PixelShuffle`` then rearranges
    them into a single channel whose height and width are ``upscale_factor``
    times those of the input.

    Args:
        upscale_factor: integer spatial upscaling factor.
        inplace: forwarded to ``nn.ReLU``; when True the activation overwrites
            its input tensor.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()
        # Keep the submodule creation order fixed. It determines the state_dict
        # key order and the sequence of random draws made by Conv2d's default init.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=1, padding=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=(3, 3),
                               stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        """Map an (N, 1, H, W) tensor to (N, 1, H*r, W*r), r = upscale_factor."""
        # The paddings keep H and W unchanged through every convolution.
        for layer in (self.conv1, self.conv2, self.conv3):
            x = self.relu(layer(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise the conv weights; biases keep their defaults."""
        relu_gain = init.calculate_gain('relu')
        # conv1..conv3 feed a ReLU, so they are scaled by the ReLU gain.
        for layer in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(layer.weight, gain=relu_gain)
        # conv4 feeds PixelShuffle rather than a ReLU and keeps the default gain.
        init.orthogonal_(self.conv4.weight)
def to_numpy(tensor_val):
    """Return *tensor_val* as a NumPy array on the CPU.

    Tensors that require grad are detached first, because ``Tensor.numpy()``
    raises on tensors that are tracked by autograd.
    """
    source = tensor_val.detach() if tensor_val.requires_grad else tensor_val
    return source.cpu().numpy()
if __name__ == '__main__':
# Create the super-resolution model by using the above model definition
torch_model = SuperResolutionNet(upscale_factor=3)
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch