import torch
import torchvision
import netron
class Classifier(torch.nn.Module):
def __init__(self):
super().__init__()
self.backbone = torchvision.models.resnet34(pretrained=False)
self.backbone.load_state_dict(torch.load("resnet34.pth", map_location=None))
def forward(self, x):
feature = self.backbone(x)
probability = torch.softmax(feature, dim=1)
return probability
dummy = torch.zeros(1, 3, 224, 224)
model = Classifier().eval()
with torch.no_grad():
model(dummy)
torch.onnx.export(
model, dummy,
"classifier.onnx",
input_names=["image"],
output_names=["prob"],
dynamic_axes={"image": {0: "batch"}, "prob": {0: "batch"}},
)
netron.start("classifier.onnx")
注意事项:保证权重pth文件已经下载,下载pth文件一直失败(难受)torchvision.models加载其他模型亦是类似。强烈建议自己搭建几个模型,有助于理解模型细节,batch维度设置成动态的,有助于多线程推理。
1)netron.start("classifier.onnx")可视化onnx模型结构