1 文件和环境准备
使用的是霹雳吧啦大神在 GitHub 上的代码：swin_transformer
推荐使用 PyTorch 1.10 及以上版本（1.9.1 会出现下文 Error_3 中的报错）。
2 Error_1
Exporting the operator roll to ONNX opset version 11 is not supported.（较新版本的 PyTorch 已支持导出 roll 算子）
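错误原因：当前 PyTorch 版本的 ONNX 导出器不支持 roll 算子。
解决方案：将 model.py 中的 torch.roll 改写为等价的"切片 + torch.cat"循环移位。注意：沿 dim=2 拼接时，输入必须是上一步沿 dim=1 移位后的结果，否则 dim=1 方向的移位会被覆盖掉。代码如下：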
# reverse cyclic shift
# Export-friendly equivalent of
#     x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
# built from slicing + torch.cat, because the ONNX exporter (opset 11) of older
# PyTorch versions cannot export torch.roll.
# Dims 1 and 2 are the axes the replaced torch.roll shifted (presumably H and W
# of a (B, H, W, C) tensor -- TODO confirm against model.py).
if self.shift_size > 0:
    # Roll by +shift_size along dim 1: move the last shift_size rows to the front.
    x = torch.cat((shifted_x[:, -self.shift_size:, :, :], shifted_x[:, :-self.shift_size, :, :]), dim=1)
    # Roll by +shift_size along dim 2. This MUST read `x` (the dim-1 result);
    # slicing `shifted_x` again would silently throw away the dim-1 shift.
    x = torch.cat((x[:, :, -self.shift_size:, :], x[:, :, :-self.shift_size, :]), dim=2)
# cyclic shift
# Export-friendly equivalent of
#     shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# built from slicing + torch.cat, because the ONNX exporter (opset 11) of older
# PyTorch versions cannot export torch.roll.
# Dims 1 and 2 are the axes the replaced torch.roll shifted (presumably H and W
# of a (B, H, W, C) tensor -- TODO confirm against model.py).
if self.shift_size > 0:
    # Roll by -shift_size along dim 1: move the first shift_size rows to the back.
    shifted_x = torch.cat((x[:, self.shift_size:, :, :], x[:, :self.shift_size, :, :]), dim=1)
    # Roll by -shift_size along dim 2. This MUST read `shifted_x` (the dim-1
    # result); slicing `x` again would silently throw away the dim-1 shift.
    shifted_x = torch.cat((shifted_x[:, :, self.shift_size:, :], shifted_x[:, :, :self.shift_size, :]), dim=2)
3 Error_2
File "/swin_transformer/model.py", line 365, in forward
x = torch.cat((shifted_x[:,-self.shift_size:,:,:], shifted_x[:,:-self.shift_size,:,:]), dims=1)
TypeError: cat() received an invalid combination of arguments - got (tuple, dims=int), but expected one of:
* (tuple of Tensors tensors, int dim, *, Tensor out)
错误原因：torch.cat 的参数名写错了。
解决方案：torch.cat 的第二个参数名是 dim，而不是 dims（dims 是 torch.roll 的参数名，改写时容易照抄出错）。
4 Error_3
File "/home/users/env/env1/lib64/python3.6/site-packages/torch/onnx/utils.py", line 890, in _graph_op
torch._C._jit_pass_onnx_node_shape_type_inference(n, _params_dict, opset_version)
RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":520, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.
错误原因：PyTorch 版本问题，出现该报错的版本是 1.9.1。
解决方案：将 PyTorch 升级到 1.10.1 即可。
5 导出onnx模型、推理一张图片
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import swin_tiny_patch4_window7_224 as create_model
def model_convert_onnx(model, input_shape, output_path, opset_version=11):
    """Export ``model`` to an ONNX file via tracing with a random dummy image.

    Args:
        model: the PyTorch module to export; put it in eval mode beforehand.
        input_shape: ``(height, width)`` of the dummy input. The traced input is
            ``(1, 3, height, width)``, so this must match the size the model is
            fed at inference time.
        output_path: destination path of the ``.onnx`` file.
        opset_version: ONNX opset to target (usually 10 or 11; default 11).

    The graph input is named ``"input1"`` and the output ``"output1"``;
    onnxruntime callers must feed the input under that name.
    """
    # Build the dummy input on the same device as the model's weights. A CPU
    # tensor fed to a CUDA model makes tracing fail with a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1], device=device)
    input_names = ["input1"]
    output_names = ["output1"]
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        verbose=True,
        keep_initializers_as_inputs=True,
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
    )
def main():
    """Compare ONNX Runtime and PyTorch predictions for a single image.

    Loads the fine-tuned Swin-T checkpoint and exports it to ONNX if the
    ``.onnx`` file does not exist yet. It then runs the same preprocessed image
    through onnxruntime and through PyTorch and prints both probability
    vectors. It also prints the largest absolute logit difference, a quick
    sanity check of the export. Finally it shows the image titled with the
    PyTorch prediction.
    """
    import onnxruntime

    # onnxruntime below runs on the CPU provider; keep PyTorch on CPU as well
    # so both backends are compared like for like.
    device = "cpu"

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "../flower_data/tulip.jpg"
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(img_path))
    # Force 3 channels: grayscale / RGBA / palette images would otherwise not
    # match the 3-channel Normalize and the model's (1, 3, H, W) input.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict (maps str(class index) -> class name)
    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(json_path))
    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    # create model and load the fine-tuned weights
    model = create_model(num_classes=5).to(device)
    model_weight_path = "./weights/model-9.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    onnx_path = './weights/swin_transformer.onnx'
    if not os.path.exists(onnx_path):
        # The export input size must match the size the PyTorch model is fed.
        model_convert_onnx(model, (img_size, img_size), onnx_path)
        print("model convert onnx finish, onnx model location:", onnx_path)

    # ---------------------------------------------------------
    # onnxruntime inference
    # ---------------------------------------------------------
    ort_session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # Ask the graph for its input name instead of hard-coding it, so this stays
    # in sync with whatever input_names the export used.
    input_name = ort_session.get_inputs()[0].name
    onnx_outputs = ort_session.run(None, {input_name: img.numpy()})
    onnx_logits = torch.squeeze(torch.from_numpy(onnx_outputs[0])).cpu()
    onnx_predict = torch.softmax(onnx_logits, dim=0)
    print("onnx_predict:", onnx_predict)

    # ---------------------------------------------------------
    # PyTorch inference
    # ---------------------------------------------------------
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
    predict = torch.softmax(output, dim=0)
    print("pt_predict:", predict)
    # Both backends should agree up to floating-point error if the export is correct.
    print("max abs logit diff (onnx vs pt):", (onnx_logits - output).abs().max().item())

    predict_cla = torch.argmax(predict).item()
    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)], prob.item()))
    plt.show()
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
参考链接
https://github.com/pytorch/pytorch/issues/78348
https://blog.csdn.net/blueblood7/article/details/121034635