import torch
# Compute device for the model and the tracing input: the CUDA device when
# PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
from net import NeuralNetworkV2
import os
if __name__ == '__main__':
    # Load the trained NeuralNetworkV2 weights from ./model.pth and export the
    # network to ONNX at ./model_sw1.onnx.

    # Example input used to trace the network during export. Its shape
    # (1, 1, 28, 28) must be one NeuralNetworkV2 accepts; presumably a batch of
    # single-channel 28x28 images -- TODO confirm against net.py.
    dummy_input = torch.randn(1, 1, 28, 28, device=device, requires_grad=True)

    dict_save_dir = "./"
    model_name = "model.pth"
    # map_location moves the stored tensors onto the current device. Without
    # it, a checkpoint saved from a GPU run cannot be loaded on a CPU-only
    # machine. NOTE: torch.load unpickles the file -- only load trusted
    # checkpoints.
    check_point = torch.load(os.path.join(dict_save_dir, model_name),
                             map_location=device)

    model = NeuralNetworkV2(num_classes=10).to(device)
    # The loaded object is passed straight to load_state_dict, i.e. the file is
    # expected to hold a bare state_dict.
    model.load_state_dict(check_point)
    # Inference mode, so mode-dependent layers (dropout, batch-norm, if the
    # network has any) are traced with their evaluation behavior.
    model.eval()

    save_path = os.path.join(dict_save_dir, "model_sw1.onnx")
    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,   # store the trained weights inside the .onnx file
        verbose=True,         # print a description of the exported graph
        input_names=["input_1"],
        output_names=["output_1"],
        # Make dimension 0 symbolic so the exported model accepts any batch
        # size, not only the batch of 1 used for tracing.
        dynamic_axes={"input_1": {0: "batch_size"},
                      "output_1": {0: "batch_size"}},
    )
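# torch转onnx模型 (converting a PyTorch model to ONNX)
# 最新推荐文章于 2024-10-02 01:49:36 发布 (web-page residue: "latest recommended article published 2024-10-02 01:49:36")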