1、creat_onnx.py
import torch
from torchvision.models import mobilenet_v2

# Build torchvision's MobileNetV2, load the local checkpoint into it and export
# the network to ONNX (./mobilenet_v2.onnx).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = mobilenet_v2()
# map_location lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine as well.
state_dict = torch.load("./mobilenet_v2-b0353104.pth", map_location=device)
# strict=False tolerates extra keys in the checkpoint, but a *missing* key would
# leave that parameter randomly initialised and the exported ONNX model would
# silently produce wrong predictions -- fail loudly instead.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
    raise RuntimeError(
        "checkpoint is missing weights for: "
        + ", ".join(incompatible.missing_keys)
    )
model.to(device)
# eval() so dropout/batch-norm are traced in inference mode.
model.eval()
# Dummy input used only to trace the graph; its values do not matter. No
# dynamic_axes are passed, so the exported input/output shapes are fixed to
# this (1, 3, 224, 224) shape.
inputs = torch.ones((1, 3, 224, 224)).to(device)
onnxpath = "./mobilenet_v2.onnx"
torch.onnx.export(
    model,
    inputs,
    onnxpath,
    export_params=True,
    verbose=False,
    input_names=['input'],
    output_names=['output'],
    opset_version=12,
)
2、pytorch_predict.py
# sample execution (requires torchvision)
# Run MobileNetV2 on a few ImageNet validation images with both PyTorch and the
# exported ONNX model (onnxruntime, CPU).
# - Prints the top PyTorch prediction per image.
# - Dumps the raw (un-softmaxed) outputs of both backends as .bin files so they
#   can be compared offline.
import os

import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import mobilenet_v2
import numpy as np
import onnxruntime

filenames = ["ILSVRC2012_val_00000001_n01751748",
             "ILSVRC2012_val_00000003_n02105855",
             "ILSVRC2012_val_00000095_n02487347",
             "ILSVRC2012_val_00000211_n02123597",
             "ILSVRC2012_val_00000290_n04328186"]

# Number of top categories printed per image.
TOP_K = 1

# Everything from here up to the loop is loop-invariant, so it is built once
# instead of once per image.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# -------pytorch model-------
model = mobilenet_v2()
# map_location: the model runs on CPU here, even if the checkpoint was saved
# from GPU tensors.
state_dict = torch.load("./mobilenet_v2-b0353104.pth", map_location='cpu')
# strict=False tolerates extra checkpoint keys, but missing keys would leave
# parameters randomly initialised and make the comparison meaningless.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
    raise RuntimeError(
        "checkpoint is missing weights for: "
        + ", ".join(incompatible.missing_keys)
    )
model.eval()
model.to('cpu')

# --------onnxruntime session--------
output_path = 'mobilenet_v2.onnx'
ort_session = onnxruntime.InferenceSession(output_path, providers=['CPUExecutionProvider'])
ort_input_name = ort_session.get_inputs()[0].name

# Read the categories (one label per line, indexed by class id).
with open("hub/imagenet_classes.txt", "r", encoding="utf-8") as f:
    categories = [s.strip() for s in f]

# ndarray.tofile does not create missing directories.
os.makedirs("./pytorch_result", exist_ok=True)
os.makedirs("./onnxruntime_result", exist_ok=True)

for filename in filenames:
    # convert("RGB"): a grayscale/CMYK JPEG would otherwise produce a 1- or
    # 4-channel tensor that the 3-channel Normalize above cannot handle.
    # The with-block closes the underlying file handle.
    with Image.open("./" + filename + ".JPEG") as img:
        input_image = img.convert("RGB")
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

    # -------pytorch_predict-------
    with torch.no_grad():
        output = model(input_batch)
    # output[0] holds the unnormalized scores for the single image in the batch.
    output[0].numpy().tofile("./pytorch_result/" + filename + ".bin")
    # Softmax turns the unnormalized scores into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_catid = torch.topk(probabilities, TOP_K)
    for prob, catid in zip(top_prob, top_catid):
        print(categories[catid], prob.item())

    # --------onnxruntime_predict--------
    ort_outs = ort_session.run(None, {ort_input_name: input_batch.numpy()})
    ort_outs[0].tofile("./onnxruntime_result/" + filename + ".bin")