import torch
import torchvision
# 加载并训练PyTorch模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()
# 创建一个示例输入张量
dummy_input = torch.randn(1, 3, 224, 224)
# 导出模型为ONNX格式
onnx_path = "model.onnx"
torch.onnx.export(model, dummy_input, onnx_path, export_params=True, opset_version=11)
print("ONNX模型已导出:", onnx_path)
pytorch模型导出为onnx模型
最新推荐文章于 2024-07-22 14:33:33 发布